diff --git a/CMakeLists.txt b/CMakeLists.txt
index 81eeecc7..1e868b83 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,7 +20,7 @@ project(xfastertransformer LANGUAGES C CXX)
 option(WITH_GPU "Build with GPU" OFF)
 if(WITH_GPU)
     message(STATUS "Notice: Building with GPU.")
-    add_definitions(-DGPU=true)
+    add_definitions(-DXFT_GPU=true)
     # Get compiler version
     execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
                     OUTPUT_VARIABLE ICPX_VERSION
@@ -35,10 +35,6 @@ else()
     message(STATUS "Notice: GCC version: ${GCC_VERSION}")
 endif()
 
-if(NOT CMAKE_BUILD_TYPE)
-    set(CMAKE_BUILD_TYPE Release)
-endif()
-
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512bw -mavx512vl -fPIC")
 if(WITH_GPU)
@@ -73,11 +69,15 @@ if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")
     endif()
 endif()
 
+if(NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE Release)
+endif()
+
 if(CMAKE_BUILD_TYPE MATCHES "Debug")
     message(STATUS "Notice: Using Debug mode.")
     set(CMAKE_C_FLAGS "-O0 -g")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g")
-    add_definitions(-DDEBUG=true)
+    add_definitions(-DXFT_DEBUG=true)
     add_definitions(-DSTEP_BY_STEP_ATTN=true)
 else()
     message(STATUS "Notice: Using Release mode.")
diff --git a/requirements-gpu.txt b/requirements-gpu.txt
new file mode 100644
index 00000000..f699e100
--- /dev/null
+++ b/requirements-gpu.txt
@@ -0,0 +1,8 @@
+-f https://download.pytorch.org/whl/torch_stable.html
+cmake==3.26.1
+sentencepiece==0.1.99
+torch==2.3.0+cpu.cxx11.abi
+transformers==4.40.0
+accelerate==0.23.0
+protobuf
+tiktoken
diff --git a/src/common/allocator.h b/src/common/allocator.h
index 2e454772..6fc80f80 100644
--- a/src/common/allocator.h
+++ b/src/common/allocator.h
@@ -15,8 +15,13 @@
 #pragma once
 #include
 #include
-#include
+#include
 #include "environment.h"
+#include
+
+#ifdef XFT_GPU
+#include
+#endif
 
 namespace xft {
 
@@ -26,10 +31,22 @@ static inline bool is_thp_alloc(size_t nbytes) {
     return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold));
 }
 
-static inline void *alloc(size_t nbytes, size_t alignment = 64) {
+static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignment = 64) {
     if (nbytes == 0) { return nullptr; }
 
-    void *data;
+    void *data = nullptr;
+
+#ifdef XFT_GPU
+    if (device != nullptr) {
+        sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+        data = sycl::malloc_device(nbytes, *gpu_queue);
+        if (data == nullptr) {
+            printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes);
+            exit(-1);
+        }
+        return data;
+    }
+#endif
 
     int err = posix_memalign(&data, alignment, nbytes);
     if (err != 0) {
@@ -47,4 +64,40 @@
 
     return data;
 }
+
+static inline void dealloc(void *data, void *device = nullptr) {
+#ifdef XFT_GPU
+    if (device != nullptr) {
+        sycl::free(data, *static_cast<sycl::queue *>(device));
+        return;
+    }
+#endif
+
+    free(data);
+}
+
+static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) {
+#ifdef XFT_GPU
+    if (device != nullptr) {
+        sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+        gpu_queue->memcpy(dst, src, size).wait();
+        return;
+    }
+#endif
+
+    memcpy(dst, src, size);
+}
+
+static inline void memsetv(void *dst, int ch, size_t size, void *device = nullptr) {
+#ifdef XFT_GPU
+    if (device != nullptr) {
+        sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+        gpu_queue->memset(dst, ch, size).wait();
+        return;
+    }
+#endif
+
+    memset(dst, ch, size);
+}
+
 } // namespace xft
\ No newline at end of file
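The allocator now takes an optional `device` argument (a `sycl::queue *` passed as `void *`) and routes through SYCL USM when it is non-null, falling back to `posix_memalign`/`free`/`memcpy`/`memset` otherwise. A minimal usage sketch of the new API, assuming an XFT_GPU build and the standard SYCL 2020 header; the `runOnGpu` flag and queue setup are illustrative, not part of this diff:

    #include <sycl/sycl.hpp>  // assumption: SYCL header providing sycl::queue
    #include "allocator.h"

    void example(size_t nbytes, bool runOnGpu) {
        sycl::queue q;  // caller-owned queue (illustrative)
        void *device = runOnGpu ? (void *)&q : nullptr;

        // GPU path: sycl::malloc_device + queue memset; CPU path: posix_memalign + memset.
        float *buf = (float *)xft::alloc(nbytes, device);
        xft::memsetv(buf, 0, nbytes, device);
        xft::dealloc(buf, device);  // sycl::free on GPU, free() on CPU
    }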
diff --git a/src/common/sequence.h b/src/common/sequence.h
index 211b69ed..6376e4a6 100644
--- a/src/common/sequence.h
+++ b/src/common/sequence.h
@@ -19,6 +19,7 @@
 #include
 #include
 
+#include "allocator.h"
 #include "environment.h"
 #include "sampling_params.h"
 
@@ -67,7 +68,7 @@ class SequenceIDManager {
 // The SequenceMeta is one sequence of batch inputs and includes the generated tokens.
 class SequenceMeta {
 public:
-    SequenceMeta(std::vector<int32_t> &_promptTokens)
+    SequenceMeta(const std::vector<int32_t> &_promptTokens)
         : sequenceID(SequenceIDManager::getInstance().createSequenceID())
         , inputSeqLen(_promptTokens.size())
         , pastSeqLen(0)
@@ -81,6 +82,16 @@ class SequenceMeta {
         , promptTokens(_inputSeqLen, 0)
         , step(0) {}
 
+    SequenceMeta(int32_t _sequenceID, const std::vector<int32_t> &_promptTokens)
+        : sequenceID(_sequenceID)
+        , inputSeqLen(_promptTokens.size())
+        , pastSeqLen(0)
+        , promptTokens(_promptTokens)
+        , step(0) {}
+
+    SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
+        : sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), promptTokens(_inputSeqLen, 0), step(0) {}
+
     ~SequenceMeta() {}
 
     int32_t getSequenceID() const { return sequenceID; }
@@ -175,7 +186,8 @@ class SequenceGroupMeta {
         groupID = sequences[0].getSequenceID();
     }
 
-    SequenceGroupMeta(std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
+    SequenceGroupMeta(const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
+            : samplingMeta(samplingMeta_) {
         sequences.reserve(samplingMeta.config.numBeams);
         for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
             sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -191,7 +203,7 @@ class SequenceGroupMeta {
         groupID = sequences[0].getSequenceID();
     }
 
-    SequenceGroupMeta(std::vector<int32_t> &_inputTokens) {
+    SequenceGroupMeta(const std::vector<int32_t> &_inputTokens) {
         sequences.reserve(samplingMeta.config.numBeams);
         for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
             sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -207,6 +219,40 @@ class SequenceGroupMeta {
         groupID = sequences[0].getSequenceID();
     }
 
+    SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
+            : samplingMeta(samplingMeta_) {
+        sequences.reserve(samplingMeta.config.numBeams);
+        for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+            sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
+        }
+        groupID = sequences[0].getSequenceID();
+    }
+
+    SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_)
+            : samplingMeta(samplingMeta_) {
+        sequences.reserve(samplingMeta.config.numBeams);
+        for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+            sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
+        }
+        groupID = sequences[0].getSequenceID();
+    }
+
+    SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens) {
+        sequences.reserve(samplingMeta.config.numBeams);
+        for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+            sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
+        }
+        groupID = sequences[0].getSequenceID();
+    }
+
+    SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen) {
+        sequences.reserve(samplingMeta.config.numBeams);
+        for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+            sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
+        }
+        groupID = sequences[0].getSequenceID();
+    }
+
     int32_t getGroupID() { return groupID; }
 
     int32_t getGroupSize() { return samplingMeta.config.numBeams; }
@@ -272,6 +318,31 @@ class SequencePool {
         return group;
     }
 
+    SequenceGroupMeta *newGroupMeta(
+            int32_t sequenceID, std::vector<int32_t> &inputTokens, SamplingMeta &samplingMeta_) {
+        auto *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_);
+        this->add(group);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) {
+        auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen, samplingMeta_);
+        this->add(group);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector<int32_t> &inputTokens) {
+        auto *group = new SequenceGroupMeta(sequenceID, inputTokens);
+        this->add(group);
+        return group;
+    }
+
+    SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen) {
+        auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen);
+        this->add(group);
+        return group;
+    }
+
     bool add(SequenceGroupMeta *sequenceGroup, bool force = false) {
         int32_t groupID = sequenceGroup->getGroupID();
         bool isSuccess = false;
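The new constructors and `newGroupMeta()` overloads accept an externally assigned sequence ID instead of drawing one from SequenceIDManager, which `common_decoder.h` below relies on when registering a sequence under an ID coordinated by the caller. A minimal sketch (token values and the ID are illustrative):

    std::vector<int32_t> prompt = {1, 15043, 3186};  // placeholder token IDs
    int32_t externalID = 42;                          // ID chosen by the serving layer

    // newGroupMeta() constructs the group under the given ID and add()s it to the pool in one call.
    auto *groupMeta = SequencePool::getInstance().newGroupMeta(externalID, prompt);
    groupMeta->get(0)->setPastSeqLen(0);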
diff --git a/src/common/transformer_ctx.h b/src/common/transformer_ctx.h
index 0faa075d..27b777bc 100644
--- a/src/common/transformer_ctx.h
+++ b/src/common/transformer_ctx.h
@@ -112,7 +112,8 @@ struct DecoderContext {
     xft::Matrix<float> qkvMatMul; // query, key, value
     xft::Matrix<float> imOut; // intermediate output
 
-    MMHelper *mmHelper;
+    MMHelper *mmHelper = nullptr;
+    void *device = nullptr;
 
     std::string configPath;
     INIReader configReader;
@@ -130,7 +131,7 @@ struct DecoderContext {
 public:
     DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize,
             const std::string &act, float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
-            int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
+            int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
             bool _useLogN = true, bool _useNTK = true, int numThreads = 0)
         : layers(_layers)
         , hiddenSize(_hiddenSize)
@@ -170,9 +171,12 @@ struct DecoderContext {
             }
         }
 
+        this->mmHelper = mmHelper;
+        this->device = device;
+
         this->rawBufSize = 4 * 32 * intermediateSize + 4 * attHeadNum * 32 * 32; // assume bs=4, seq=32
-        this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
-        memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
+        this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
+        xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);
 
         if (act == "relu") {
             this->actType = RELU;
@@ -240,8 +244,12 @@ struct DecoderContext {
     bool cached(const std::string &name) { return SimpleMemPool::instance().cached(name); }
 
     template <typename T>
-    T *getBuffer(const std::string &name, size_t size, size_t alignment = 64) {
-        return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, alignment);
+    T *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) {
+        return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment);
+    }
+
+    void freeBuffer(const std::string &name) {
+        SimpleMemPool::instance().freeBuffer(name);
     }
 
     void dump() {
@@ -286,10 +294,10 @@ struct DecoderContext {
         uint64_t total = size1 + size2 + size3;
         if (total > this->rawBufSize) {
             this->rawBufSize = total;
-            free(this->rawBuffer);
+            if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
-            this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
-            memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
+            this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
+            xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);
         }
 
         // Assign the buffer
@@ -312,5 +320,9 @@ struct DecoderContext {
         return rawBufSize - size1 - size2;
     }
 
-    ~DecoderContext() { free(this->rawBuffer); }
+    ~DecoderContext() {
+#ifndef XFT_GPU
+        if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
+#endif
+    }
 };
\ No newline at end of file
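DecoderContext now receives the MMHelper (and optionally a device queue) through its constructor instead of having `mmHelper` assigned after construction, so the raw buffer can be placed on the right device from the start. The new call order, as used by the layer Impl functions later in this diff (argument values are illustrative):

    MMHelper *mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
    DecoderContext *ctx = new DecoderContext(/*layers*/ 1, /*hiddenSize*/ 4096, /*headSize*/ 128,
            /*attHeadNum*/ 32, /*kvHeadNum*/ 32, /*imSize*/ 11008, "silu", 1e-6, /*vocabSize*/ 0,
            /*embeddingSize*/ 0, /*maxPositions*/ 4096, /*maxPosEmbed*/ 4096, /*maxSeqLength*/ -1,
            /*splitIdx*/ 0, /*splits*/ 1, mmHelper /*, device, then the defaulted params */);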
diff --git a/src/kernels/attention_kernels.cpp b/src/kernels/attention_kernels.cpp
index 2ace33d7..497c99ed 100644
--- a/src/kernels/attention_kernels.cpp
+++ b/src/kernels/attention_kernels.cpp
@@ -66,7 +66,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
         small_sgemm_bf16bf16f32_b(true, m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)baseB, ldb, C, ldc, blkIndices,
                 cacheBlkStride, cacheBlkSize);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         if (b == 0 && i == 0) {
             printf("Q * K, first head:\n");
             auto p = C;
@@ -78,7 +78,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
         // Softmax(Q * K)
         small_softmax_f32(C, scale, n);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         if (b == 0 && i == 0) {
             printf("Softmax(Q * K), first head:\n");
             auto p = C;
@@ -100,7 +100,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
         small_sgemm_f32bf16bf16_b(false, m, n, k, C, lda, (XDNN_BF16 *)baseB, ldb, (XDNN_BF16 *)baseC, ldc,
                 blkIndices, cacheBlkStride, cacheBlkSize);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         if (b == 0 && i == 0) {
             printf("Softmax(Q * K) * V, first head:\n");
             auto p = C;
diff --git a/src/kernels/attention_kernels.h b/src/kernels/attention_kernels.h
index ca0dac8f..1a9d96bf 100644
--- a/src/kernels/attention_kernels.h
+++ b/src/kernels/attention_kernels.h
@@ -237,7 +237,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
             xdnn_small_amx_sgemm_bf16bf16bf16_compute(
                     m, endSeq, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 auto B = key + offsets[b] * kvStride + kvHeadIdx * headSize;
                 printf("mnk=%d,%d,%d, ldabc=%d,%d,%d, A[0]=%f, B[0]=%f, packedB[0]=%f\n", m, n, k, lda, ldb, ldc,
@@ -260,7 +260,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
                 memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
             }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 printf("Softmax(Q * Kᵀ), first head:\n");
                 auto p = C;
@@ -290,7 +290,7 @@ void selfAttention_SeparateCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_
             xdnn_small_amx_sgemm_bf16bf16bf16_compute(
                     m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 printf("Softmax(Q * Kᵀ) * V, first head:\n");
                 auto p = C;
@@ -306,7 +306,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
         int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes,
         const float scale, const float *alibiSlopes, int threadNum, const Lambda1 &getKCache, const Lambda2 &getVCache) {
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     printf("Q[0]=%f, K[0]=%f, V[0]=%f\n", (float)query[0], (float)key[0], (float)value[0]);
     printf("kvHeadNum=%d, headSize=%d, qStride=%d, kvStride=%d, batchSize=%d\n", kvHeadNum, headSize, qStride,
             kvStride, batchSize);
@@ -337,7 +337,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
     bfloat16_t *scores = (bfloat16_t *)SimpleMemPool::instance().getBuffer(
             "qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t));
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0],
            offsets[0], kvStride);
 #endif
@@ -389,7 +389,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
             xdnn_small_amx_sgemm_bf16bf16bf16_compute(
                     m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 printf("mnk=%d,%d,%d, ldabc=%d,%d,%d, A[0]=%f, B[0]=%f, packedB[0]=%f\n", m, n, k, lda, ldb, ldc,
                         (float)A[0], (float)B[0], (float)packedB[0]);
@@ -411,7 +411,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
                 memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t));
             }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 printf("Softmax(Q * Kᵀ), first head:\n");
                 auto p = C;
@@ -430,7 +430,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t *
             xdnn_small_amx_sgemm_bf16bf16bf16_compute(
                     m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 printf("Softmax(Q * Kᵀ) * V, first head:\n");
                 auto p = C;
"qkscore", threadNum * mBlockSize * maxScoreStride * sizeof(bfloat16_t)); -#ifdef DEBUG +#ifdef XFT_DEBUG printf("maxTokenSize=%d, tokenSizes[0]=%d, offsets[0]=%d, kvStride=%d\n", maxTokenSize, tokenSizes[0], offsets[0], kvStride); #endif @@ -389,7 +389,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedB, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("mnk=%d,%d,%d, ldabc=%d,%d,%d, A[0]=%f, B[0]=%f, packedB[0]=%f\n", m, n, k, lda, ldb, ldc, (float)A[0], (float)B[0], (float)packedB[0]); @@ -411,7 +411,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * memset(C + seq * ldc + elements, 0, (tokens - elements) * sizeof(bfloat16_t)); } -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Kᵀ), first head:\n"); auto p = C; @@ -430,7 +430,7 @@ void selfAttention_FusedCopy(bfloat16_t *output, bfloat16_t *query, bfloat16_t * xdnn_small_amx_sgemm_bf16bf16bf16_compute( m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)packedV, (XDNN_BF16 *)C, ldc); -#ifdef DEBUG +#ifdef XFT_DEBUG if (b == 0 && i == 0) { printf("Softmax(Q * Kᵀ) * V, first head:\n"); auto p = C; diff --git a/src/kernels/rotary_embedding_kernels.cpp b/src/kernels/rotary_embedding_kernels.cpp index 566a6340..9812bd35 100644 --- a/src/kernels/rotary_embedding_kernels.cpp +++ b/src/kernels/rotary_embedding_kernels.cpp @@ -450,4 +450,75 @@ void qwenApplyRotaryPosEmbeding(float16_t *query, float16_t *key, int qStride, i maxSupportedSeqLength, qkShape, positionIds); } +#ifdef XFT_GPU +// For LLaMA +template +static inline void llamaApplyRotaryPosEmbeding(void *device, T *query, T *key, int qStride, int kStride, + const float *emb_cos, const float *emb_sin, int inv_freq_size, const int *qkShape, const int *positionIds) { + int dim = inv_freq_size * 2; + REQUIRES(dim == qkShape[3], "Incorrect shape, this dimention is not the head size."); + + const int batchSize = qkShape[0]; + const int seqLen = qkShape[1]; + const int qHeads = qkShape[2]; + const int kHeads = qkShape[4]; + const int head_num = std::max(qHeads, kHeads); + const int head_size = qkShape[3]; + const int half_head_size = (head_size + 1) / 2; + using namespace sycl; + + // Reorder input + sycl::queue *gpu_queue = static_cast(device); + sycl::buffer positionIdsBuf(positionIds, sycl::range<1>(seqLen)); + gpu_queue->submit([&](sycl::handler &cgh) { + sycl::accessor position(positionIdsBuf, cgh, sycl::read_only); + sycl::range<3> globalSize(batchSize * seqLen, head_num, half_head_size); + sycl::range<3> workGroupSize(1, 1, 1); + + cgh.parallel_for(sycl::nd_range(globalSize, workGroupSize), [=](sycl::nd_item<3> item) { + size_t idx_bs_seq = item.get_global_id(0); + size_t idx_head_num = item.get_global_id(1); + size_t idx_half_head_dim = item.get_global_id(2); + + size_t pos = position[idx_bs_seq % seqLen]; + const sycl::half cos = (sycl::half)emb_cos[pos * half_head_size + idx_half_head_dim]; + const sycl::half sin = (sycl::half)emb_sin[pos * half_head_size + idx_half_head_dim]; + + sycl::half *q = (sycl::half *)query + idx_bs_seq * qStride + idx_head_num * head_size + idx_half_head_dim; + sycl::half *k = (sycl::half *)key + idx_bs_seq * kStride + idx_head_num * head_size + idx_half_head_dim; + + if (idx_head_num < qHeads) { + auto q1 = q[0]; + q[0] = q1 * cos - q[half_head_size] * sin; + q[half_head_size] = q[half_head_size] * cos + q1 * sin; + } + if (idx_head_num < 
diff --git a/src/layers/attention.cpp b/src/layers/attention.cpp
index 0e9d0669..bd15b0af 100644
--- a/src/layers/attention.cpp
+++ b/src/layers/attention.cpp
@@ -77,15 +77,16 @@ void AttentionLLaMAImpl(DataType dt, int batchSize, int inputSeqLen, int attHead
     using ATTENTION = Attention;
     static std::unordered_map llama_attention_hub;
 
+    static MMHelper *mmHelper;
     static DecoderContext *ctx;
     static KVCacheManager *kvCacheMgr;
 
     if (ctx == nullptr || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->attHeadSize != attHeadDim))) {
         if (ctx != nullptr) delete ctx;
         printf(">> create context: %d %d\n", hiddenSize, attHeadDim);
+        mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
         ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, 1, "silu", 1e-6, 0, 0, maxPositions,
-                maxPosEmbed, -1, 0, 1);
-        ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+                maxPosEmbed, -1, 0, 1, mmHelper);
         if (kvCacheMgr != nullptr) delete kvCacheMgr;
         kvCacheMgr = new KVCacheManager(1);
     }
diff --git a/src/layers/attention.h b/src/layers/attention.h
index eb6ea4ac..135d9bd4 100644
--- a/src/layers/attention.h
+++ b/src/layers/attention.h
@@ -25,13 +25,13 @@
 #include "gemm_kernel_ext.h"
 #include "kvcache_tensor.h"
 #include "matmul_helper.h"
+#include "rms_norm.h"
+#include "rotary_embedding.h"
 #include "sequence.h"
 #include "simple_mem_pool.h"
 #include "transformer_ctx.h"
 #include "transformer_util.h"
 
-#include "rotary_embedding.h"
-
 /**
  * WeiT: weight data type
  * InT: input data type
@@ -46,10 +46,13 @@
 template
 class Attention {
 public:
-    Attention(int layerId, DecoderContext *ctx) : layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed) {
+    Attention(int layerId, DecoderContext *ctx)
+        : layerId(layerId), qkpo(ctx->attHeadSize, ctx->maxPosEmbed), norm(ctx) {
         //todo(marvin): clear this code after all rotary_emb refactor
-        if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) { qkpo = LlamaRotaryEmbedding(ctx); }
+        if constexpr (std::is_same<QKPO_CLS, LlamaRotaryEmbedding>::value) {
+            qkpo = LlamaRotaryEmbedding(ctx);
+        }
 
         // Group attention or multi-head attention (multi-head attn is a special case of group attn)
         if (ctx->attHeadNum % ctx->kvHeadNum == 0) {
@@ -88,7 +91,6 @@ class Attention {
         int qResponsibleCols = (this->endQHead - this->startQHead) * headSize;
         int kvResponsibleCols = (this->endKVHead - this->startKVHead) * headSize;
         int responsibleCols = qResponsibleCols + 2 * kvResponsibleCols;
-        qkvWeight.Resize(hiddenSize, responsibleCols);
 
         constexpr int sizeFactor = std::is_same_v ? 2 : 1;
@@ -138,13 +140,25 @@ class Attention {
         xft::Matrix<WeiT> convertedqkvWeight;
         ctx->mmHelper->convertWeight(trans, hiddenSize, responsibleCols, concatBuf, concatScale, concatZero,
                 convertedqkvWeight, qkvWeightScale, qkvWeightZero, qkvWeightSum);
+
+#ifdef XFT_GPU
+        xft::Matrix<WeiT> qkvWeightT;
+        qkvWeightT.Resize(hiddenSize, responsibleCols);
+        ctx->mmHelper->transposeWeight(trans, convertedqkvWeight, qkvWeightT);
+
+        WeiT *qkvWeiData = (WeiT *)xft::alloc(hiddenSize * responsibleCols * sizeof(WeiT), ctx->device);
+        qkvWeight.Assign(qkvWeiData, hiddenSize, responsibleCols, responsibleCols);
+        xft::memcopy(qkvWeight.Data(), qkvWeightT.Data(), hiddenSize * responsibleCols * sizeof(WeiT), ctx->device);
+#else
+        qkvWeight.Resize(hiddenSize, responsibleCols);
         ctx->mmHelper->packWeight(trans, convertedqkvWeight, qkvWeight);
+#endif
 
         free(concatBuf);
         free(concatScale);
         free(concatZero);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("attention qkv weight: [%d, %d] (%d)\n", convertedqkvWeight.Rows(), convertedqkvWeight.Cols(),
                 convertedqkvWeight.Stride());
         dbg.dumpMatrix(convertedqkvWeight);
@@ -165,16 +179,30 @@ class Attention {
 
         // Weights for attention output
         // Horizontally split the weight, as the source (PyTorch weight) is transposed, thus looks like vertically
-        xft::Matrix<WeiT> convertedWeight;
+        xft::Matrix<WeiT> convertedOutWeight;
         ctx->mmHelper->convertWeight(trans, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, attnOutWeight, attnOutScale,
-                attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedWeight,
+                attnOutZero, this->startQHead * headSize, qResponsibleCols, false, convertedOutWeight,
                 attnOutputWeightScale, attnOutputWeightZero, attnOutputWeightSum, true);
-        ctx->mmHelper->packWeight(trans, convertedWeight, attnOutputWeight);
-#ifdef DEBUG
-        dbg.debugPrint(">>> attention output weight: [%d, %d] (%d)\n", convertedWeight.Rows(), convertedWeight.Cols(),
-                convertedWeight.Stride());
-        dbg.dumpMatrix(convertedWeight);
+
+#ifdef XFT_GPU
+        xft::Matrix<WeiT> outWeightT;
+        outWeightT.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize);
+        ctx->mmHelper->transposeWeight(trans, convertedOutWeight, outWeightT);
+
+        WeiT *outWeiData
+                = (WeiT *)xft::alloc(ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT), ctx->device);
+        attnOutputWeight.Assign(outWeiData, ctx->attHeadNum * ctx->attHeadSize, hiddenSize, hiddenSize);
+        int outWeightTSize = ctx->attHeadNum * ctx->attHeadSize * hiddenSize * sizeof(WeiT);
+        xft::memcopy(attnOutputWeight.Data(), outWeightT.Data(), outWeightTSize, ctx->device);
+#else
+        attnOutputWeight.Resize(ctx->attHeadNum * ctx->attHeadSize, hiddenSize);
+        ctx->mmHelper->packWeight(trans, convertedOutWeight, attnOutputWeight);
+#endif
+
+#ifdef XFT_DEBUG
+        dbg.debugPrint(">>> attention output weight: [%d, %d] (%d)\n", convertedOutWeight.Rows(),
+                convertedOutWeight.Cols(), convertedOutWeight.Stride());
+        dbg.dumpMatrix(convertedOutWeight);
         dbg.debugPrint("attention output packed weight: [%d, %d] (%d)\n", attnOutputWeight.Rows(),
                 attnOutputWeight.Cols(), attnOutputWeight.Stride());
         dbg.dumpMatrix(attnOutputWeight);
@@ -194,7 +222,7 @@ class Attention {
         if (doLNorm) this->norm.setWeight(gamma1, beta1, hiddenSize);
     }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     void setDebugger(const Debugger &debugger) { this->dbg = debugger; }
 #endif
 
@@ -238,7 +266,7 @@ class Attention {
         auto &qkvMatMul = ctx->qkvMatMul;
         xft::Matrix<ImT> qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("---- DecoderLayer.forward (useSelfAttn=%d) ----\n", useSelfAttn);
         dbg.debugPrint("input:\n");
         dbg.dumpMatrix(inputBuffer);
@@ -249,7 +277,7 @@ class Attention {
             norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
                     imBuffer.Stride(), epsilon);
         }
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("layer norm:\n");
         dbg.dumpMatrix(imBuffer);
         dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols());
@@ -273,7 +301,7 @@ class Attention {
         xft::Matrix<ImT> key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols);
         xft::Matrix<ImT> value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride());
         dbg.dumpMatrix(query);
         dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride());
@@ -301,13 +329,22 @@ class Attention {
         }
         t3.release();
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride());
         dbg.dumpMatrix(query);
         dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride());
         dbg.dumpMatrix(key);
 #endif
 
+#ifdef XFT_GPU
+        int64_t qkvSize = qkvRows * qkvStride * sizeof(ImT);
+        ImT *qkvTmp = (ImT *)xft::alloc(qkvSize);
+        xft::memcopy(qkvTmp, qkvGroupMatMul.Data(), qkvSize, ctx->device); // error: need CPU ptr and GPU ptr
+        query.Assign(qkvTmp, inputBuffer.Rows(), qCols, qkvCols);
+        key.Assign(qkvTmp + qCols, inputBuffer.Rows(), kvCols, qkvCols);
+        value.Assign(qkvTmp + qCols + kvCols, inputBuffer.Rows(), kvCols, qkvCols);
+#endif
+
         // Revise attnFactor before softmax (for some models, attnFactor may be not the default value)
         // We initially introduced the code for ChatGLM, but eventually found it has no difference and was unnecessary.
         // However, we have chosen to keep it in the codebase in case it becomes useful for future models.
@@ -324,6 +361,12 @@ class Attention {
         // For multiple nodes inference, not the whole result buffer
         xft::Matrix<ImT> attnSplit(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);
 
+#ifdef XFT_GPU
+        int64_t attnSplitSize = imBuffer.Rows() * qCols * sizeof(ImT);
+        ImT *attnSplitTmp = (ImT *)xft::alloc(attnSplitSize);
+        attnSplit.Assign(attnSplitTmp, imBuffer.Rows(), qCols, qCols);
+#endif
+
         if (pastSeqLen == 0) {
             if (ctx->inputSeqLen > getFlashThresh()) {
                 flashAttention(ctx, query, key, value, attnSplit, presentKey, presentValue, attnMask, pastSeqLen);
@@ -337,7 +380,13 @@ class Attention {
         }
         t4.release();
 
-#ifdef DEBUG
+#ifdef XFT_GPU
+        xft::memcopy(imBuffer.Data(), attnSplit.Data(), attnSplitSize, ctx->device);
+        attnSplit.Assign(imBuffer.Data(), imBuffer.Rows(), qCols, qCols);
+        xft::dealloc(qkvTmp);
+#endif
+
+#ifdef XFT_DEBUG
         dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(),
                 attnSplit.Cols(), attnSplit.Stride());
         dbg.dumpMatrix(attnSplit);
@@ -380,7 +429,7 @@ class Attention {
         }
         t5.release();
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
                 outBuffer.Stride());
         dbg.dumpMatrix(outBuffer);
@@ -389,7 +438,7 @@ class Attention {
         if (doLnAfter) {
             TimeLine t6("result.layer_norm");
             norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride());
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(),
                     outBuffer.Stride());
             dbg.dumpMatrix(outBuffer);
@@ -423,7 +472,7 @@ class Attention {
         auto &qkvMatMul = ctx->qkvMatMul;
         xft::Matrix<ImT> qkvGroupMatMul((ImT *)qkvMatMul.Data(), qkvRows, qkvCols, qkvStride);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("---- DecoderLayer.forward ----\n");
         dbg.debugPrint("input:\n");
         dbg.dumpMatrix(inputBuffer);
@@ -434,7 +483,7 @@ class Attention {
             norm.forward(inputBuffer.Data(), imBuffer.Data(), inputBuffer.Rows(), inputBuffer.Stride(),
                     imBuffer.Stride(), epsilon);
         }
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("layer norm:\n");
         dbg.dumpMatrix(imBuffer);
         dbg.debugPrint("qkvWeight [%d, %d]:\n", this->qkvWeight.Rows(), this->qkvWeight.Cols());
@@ -458,7 +507,7 @@ class Attention {
         xft::Matrix<ImT> key(qkvGroupMatMul, 0, inputBuffer.Rows(), qCols, kvCols);
         xft::Matrix<ImT> value(qkvGroupMatMul, 0, inputBuffer.Rows(), qkCols, kvCols);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("Q[%d,%d](%d):\n", query.Rows(), query.Cols(), query.Stride());
         dbg.dumpMatrix(query);
         dbg.debugPrint("K[%d,%d](%d):\n", key.Rows(), key.Cols(), key.Stride());
@@ -488,7 +537,7 @@ class Attention {
         }
         t3.release();
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("Q[%d,%d](%d) after post op:\n", query.Rows(), query.Cols(), query.Stride());
         dbg.dumpMatrix(query);
         dbg.debugPrint("K[%d,%d](%d) after post op:\n", key.Rows(), key.Cols(), key.Stride());
@@ -524,7 +573,7 @@ class Attention {
         }
         t4.release();
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint(">>> attention_%d (softmax * value): [%d, %d] (%d)\n", ctx->splitIdx, attnSplit.Rows(),
                 attnSplit.Cols(), attnSplit.Stride());
         dbg.dumpMatrix(attnSplit);
@@ -567,7 +616,7 @@ class Attention {
         }
         t5.release();
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint(">>> attention output/projection[%d, %d] (%d):\n", outBuffer.Rows(), outBuffer.Cols(),
                 outBuffer.Stride());
         dbg.dumpMatrix(outBuffer);
@@ -576,7 +625,7 @@ class Attention {
         if (!doLnBefore) {
             TimeLine t6("result.layer_norm");
             norm.forward(outBuffer.Data(), outBuffer.Data(), outBuffer.Rows(), outBuffer.Stride(), outBuffer.Stride());
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint("LayerNorm after attention: [%d, %d] (%d)\n", outBuffer.Rows(), outBuffer.Cols(),
                     outBuffer.Stride());
             dbg.dumpMatrix(outBuffer);
@@ -894,7 +943,7 @@ class Attention {
 
             this->gemm1(A, keyMatInfo, C, m, n, headSize, lda, ldc);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 dbg.debugPrint("Q * K, first head:\n");
                 auto p = scoreBuf;
@@ -907,7 +956,7 @@ class Attention {
             // Softmax(Q * K)
             this->softmax(ctx, C, getMask(attnMask, b, i, queryLen, keyLen), m, n, ldc, startSeq);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 dbg.debugPrint("Softmax(Q * K), first head:\n");
                 auto p = scoreBuf;
@@ -925,7 +974,7 @@ class Attention {
             auto output = result.Row(b * ctx->inputSeqLen + startSeq) + i * ctx->attHeadSize;
             this->gemm2(C, valueMat, output, m, headSize, keyLen, scoreStride, result.Stride());
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             if (b == 0 && i == 0) {
                 dbg.debugPrint("Softmax(Q * K) * V, first head:\n");
                 auto p = output;
@@ -1166,7 +1215,7 @@ class Attention {
     int endQHead;
     int startKVHead;
     int endKVHead;
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     Debugger dbg;
 #endif
 };
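On XFT_GPU builds, every weight setter in this diff (qkvWeight and attnOutputWeight here, catWeights/downWeight in mlp_llama.h, weight in dist_linear.h) replaces the CPU-specific packWeight() layout with transposeWeight() into a host staging matrix followed by a copy into a device allocation. Condensed, the recurring pattern is (a sketch with abbreviated argument lists, not a literal excerpt):

    xft::Matrix<WeiT> converted, transposed;
    ctx->mmHelper->convertWeight(trans, rows, cols, src, /*...*/ converted, scale, zero, sum);
    #ifdef XFT_GPU
    transposed.Resize(rows, cols);
    ctx->mmHelper->transposeWeight(trans, converted, transposed);
    WeiT *devPtr = (WeiT *)xft::alloc(rows * cols * sizeof(WeiT), ctx->device);  // USM device memory
    weight.Assign(devPtr, rows, cols, cols);
    xft::memcopy(weight.Data(), transposed.Data(), rows * cols * sizeof(WeiT), ctx->device);
    #else
    weight.Resize(rows, cols);
    ctx->mmHelper->packWeight(trans, converted, weight);  // CPU-only packed layout
    #endif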
diff --git a/src/layers/decoder_block.h b/src/layers/decoder_block.h
index 4a18c13d..ba352066 100644
--- a/src/layers/decoder_block.h
+++ b/src/layers/decoder_block.h
@@ -147,7 +147,7 @@ class DecoderBlock {
         int kvSize = attHeadSize * kvHeadNum;
         int qkvSize = qSize + 2 * kvSize;
 
-#define ALLOC(size, alignment) xft::alloc((size), (alignment))
+#define ALLOC(size, alignment) xft::alloc((size), nullptr, (alignment))
         OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64);
         float *qkvScales = nullptr;
         float *qkvZeros = nullptr;
diff --git a/src/layers/decoder_layer.cpp b/src/layers/decoder_layer.cpp
index d1017648..02f13cbf 100644
--- a/src/layers/decoder_layer.cpp
+++ b/src/layers/decoder_layer.cpp
@@ -85,6 +85,7 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
     using DECODER = Decoder, LlamaMLP>;
     static std::unordered_map llama_layer_hub;
 
+    static MMHelper *mmHelper;
     static DecoderContext *ctx;
     static KVCacheManager *kvCacheMgr;
 
@@ -104,9 +105,9 @@ void LayerLLaMAImpl(DataType dt, ActivationType at, NormType nt, int batchSize,
         || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
         if (ctx != nullptr) delete ctx;
         printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
+        mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
         ctx = new DecoderContext(1, hiddenSize, attHeadDim, attHeadNum, kvHeadNum, intermediateSize, actType, 1e-6, 0,
-                0, maxPositions, maxPosEmbed, -1, 0, 1);
-        ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+                0, maxPositions, maxPosEmbed, -1, 0, 1, mmHelper);
         if (kvCacheMgr != nullptr) delete kvCacheMgr;
         kvCacheMgr = new KVCacheManager(1);
     }
diff --git a/src/layers/decoder_layer.h b/src/layers/decoder_layer.h
index 9a44b13f..3cb58736 100644
--- a/src/layers/decoder_layer.h
+++ b/src/layers/decoder_layer.h
@@ -59,11 +59,11 @@ class Decoder {
         : layerIdx(_layerIdx)
         , attn(_layerIdx, _ctx)
         , mlp(_ctx)
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         , dbg(Debugger::formatStr("%d_%d.csv", _layerIdx, _ctx->splitIdx))
 #endif
     {
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         attn.setDebugger(dbg);
         mlp.setDebugger(dbg);
 #endif
@@ -126,7 +126,7 @@ class Decoder {
     ATTN_CLS attn;
     MLP_CLS mlp;
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     Debugger dbg;
 #endif
 };
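Because the new `device` parameter sits between `nbytes` and `alignment` in `xft::alloc()`, any caller that previously passed an alignment as the second argument must be updated — which is exactly what the ALLOC macro change in decoder_block.h above does:

    // old API: xft::alloc(size, alignment)
    // new API: xft::alloc(size, device, alignment) — pass nullptr as the device for host allocations
    void *p = xft::alloc(1024, nullptr, 64);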
diff --git a/src/layers/dist_linear.h b/src/layers/dist_linear.h
index 41c17bf0..b118b5fb 100644
--- a/src/layers/dist_linear.h
+++ b/src/layers/dist_linear.h
@@ -59,14 +59,25 @@ class DistLinear {
         int K = inputSize;
         int N = this->splitSize;
 
-        weight.Resize(K, N);
+
         scaleWeight.Resize(N);
         zeroWeight.Resize(N);
 
         xft::Matrix<WeiT> quantizedWeight;
         ctx->mmHelper->convertWeight(
                 true, K, N, w + splitOffset * K, nullptr, nullptr, quantizedWeight, scaleWeight, zeroWeight, sumWeight);
+
+#ifdef XFT_GPU
+        xft::Matrix<WeiT> tWeight;
+        tWeight.Resize(K, N);
+        ctx->mmHelper->transposeWeight(true, quantizedWeight, tWeight);
+
+        WeiT *input_data = (WeiT *)xft::alloc(K * N * sizeof(WeiT), ctx->device);
+        weight.Assign(input_data, K, N, N);
+        xft::memcopy(weight.Data(), tWeight.Data(), tWeight.Rows() * tWeight.Cols() * sizeof(WeiT), ctx->device);
+#else
+        weight.Resize(K, N);
         ctx->mmHelper->packWeight(true, quantizedWeight, weight);
+#endif
 
         // Copy Bias
         if (b) {
diff --git a/src/layers/layer_norm.cpp b/src/layers/layer_norm.cpp
index 4ebfc6b7..012f8460 100644
--- a/src/layers/layer_norm.cpp
+++ b/src/layers/layer_norm.cpp
@@ -31,17 +31,24 @@ LayerNorm::LayerNorm() {
     normSize = 0;
 }
 
+LayerNorm::LayerNorm(DecoderContext *ctx) {
+    device = ctx->device;
+    gamma = nullptr;
+    beta = nullptr;
+    normSize = 0;
+}
+
 LayerNorm::~LayerNorm() {
-    if (gamma) { free(gamma); }
-    if (beta) { free(beta); }
+    if (gamma) { xft::dealloc(gamma, device); }
+    if (beta) { xft::dealloc(beta, device); }
 }
 
 void LayerNorm::setWeight(const float *gamma, const float *beta, int cols) {
     this->normSize = cols;
-    this->gamma = (float *)xft::alloc(cols * sizeof(float));
-    this->beta = (float *)xft::alloc(cols * sizeof(float));
-    memcpy(this->gamma, gamma, cols * sizeof(float));
-    memcpy(this->beta, beta, cols * sizeof(float));
+    this->gamma = (float *)xft::alloc(cols * sizeof(float), device);
+    this->beta = (float *)xft::alloc(cols * sizeof(float), device);
+    xft::memcopy(this->gamma, gamma, cols * sizeof(float), device);
+    xft::memcopy(this->beta, beta, cols * sizeof(float), device);
 }
 
 void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaPath, int cols) {
@@ -52,11 +59,21 @@ void LayerNorm::setWeight(const std::string &gammaPath, const std::string &betaP
 
 // input and output are in shape of (rows, normSize)
 // TODO: column-wise parallel
+#ifdef XFT_GPU
+void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("LayerNorm.forward");
+    const float *pgamma = gamma;
+    const float *pbeta = beta;
+    // TODO: Add LayerNorm Impl
+    printf("%s:%d: Could not forward in LayerNorm with undefined data type.\n", __FILE__, __LINE__);
+    exit(-1);
+}
+#else
 void LayerNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("LayerNorm.forward");
     const float *pgamma = gamma;
     const float *pbeta = beta;
     invokeLayerNorm(output, input, pgamma, pbeta, rows, normSize, iStride, oStride);
 }
-
+#endif
 } // namespace xft
\ No newline at end of file
diff --git a/src/layers/layer_norm.h b/src/layers/layer_norm.h
index 8b554648..75d9409b 100644
--- a/src/layers/layer_norm.h
+++ b/src/layers/layer_norm.h
@@ -16,6 +16,7 @@
 #include
 
 #include "weight_util.h"
+#include "transformer_ctx.h"
 
 namespace xft {
 
@@ -23,6 +24,7 @@ namespace xft {
 class LayerNorm {
 public:
     LayerNorm();
+    LayerNorm(DecoderContext *ctx);
     ~LayerNorm();
 
     void setWeight(const float *gamma, const float *beta, int cols);
@@ -37,6 +39,7 @@ class LayerNorm {
     float *gamma = nullptr;
     float *beta = nullptr;
+    void *device = nullptr;
 };
 
 } // namespace xft
\ No newline at end of file
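LayerNorm now captures the context's device at construction and routes its weight copies through the device-aware allocator; note that the GPU `forward()` above is still a stub that aborts. Typical setup (a sketch; pointers and sizes are illustrative):

    xft::LayerNorm norm(ctx);                       // picks up ctx->device
    norm.setWeight(gammaPtr, betaPtr, hiddenSize);  // weights land in CPU or GPU memory accordingly
    norm.forward(input, output, rows, hiddenSize, hiddenSize, 1e-5f);  // CPU builds only, for now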
diff --git a/src/layers/mlp_chatglm2.h b/src/layers/mlp_chatglm2.h
index f2f488d0..885a6280 100644
--- a/src/layers/mlp_chatglm2.h
+++ b/src/layers/mlp_chatglm2.h
@@ -94,7 +94,7 @@ class ChatGLM2MLP : public LlamaMLP {
             ctx->mmHelper->convertWeight(ctx, trans, intermediateSize, hiddenSize, downW, nullptr, nullptr, false,
                     convertedDownWeight, this->downWeightScale, this->downWeightZero, this->downWeightSum);
             ctx->mmHelper->packWeight(trans, convertedDownWeight, this->downWeight);
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             this->dbg.debugPrint("convertedGateWeight [%d, %d](%d):\n", convertedGateWeight.Rows(),
                     convertedGateWeight.Cols(), convertedGateWeight.Stride());
             this->dbg.dumpMatrix(convertedGateWeight);
@@ -120,9 +120,10 @@ class ChatGLM2MLP : public LlamaMLP {
             this->dbg.dumpMatrix(this->downWeight);
 #endif
         // norm.setWeight(normW, NULL, hiddenSize);
-        if (normW) {
-            this->normWeight.Resize(hiddenSize);
-            memcpy(this->normWeight.Data(), normW, sizeof(float) * hiddenSize);
-        }
+
+        if (normW) { norm.setWeight(normW, nullptr, hiddenSize); }
     }
+
+private:
+    using LlamaMLP::norm;
 };
diff --git a/src/layers/mlp_llama.cpp b/src/layers/mlp_llama.cpp
index 816c5d69..50d5e79e 100644
--- a/src/layers/mlp_llama.cpp
+++ b/src/layers/mlp_llama.cpp
@@ -26,6 +26,7 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize,
     using MLP = LlamaMLP;
     static std::unordered_map llama_mlp_hub;
 
+    static MMHelper *mmHelper;
     static DecoderContext *ctx;
 
     std::string actType;
@@ -44,8 +45,8 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize,
         || (ctx != nullptr && (ctx->hiddenSize != hiddenSize || ctx->intermediateSize != intermediateSize))) {
         if (ctx != nullptr) delete ctx;
         printf(">> create context: %d %d\n", hiddenSize, intermediateSize);
-        ctx = new DecoderContext(1, hiddenSize, 1, 1, 1, intermediateSize, actType, 1e-6, 0, 0, 0, 0, 0, 0, 1);
-        ctx->mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+        mmHelper = new MMHelper(Env::getInstance().getEngineKind(), Env::getInstance().getEngineIndex());
+        ctx = new DecoderContext(1, hiddenSize, 1, 1, 1, intermediateSize, actType, 1e-6, 0, 0, 0, 0, 0, 0, 1, mmHelper);
     }
 
     // create hash key and value: if hidden and intermediateSize is changed , then memory pointer is also changed.
@@ -58,7 +59,7 @@ void MLPLLaMAImpl(DataType dt, ActivationType at, int numTokens, int hiddenSize,
     auto it_created = llama_mlp_hub.find(llama_mlp_key);
     if (it_created == llama_mlp_hub.end()) {
         // MLP &llama_mlp = MLP::getInstance();
-        llama_mlp = new MLP();
+        llama_mlp = new MLP(ctx);
         llama_mlp->setWeights(ctx, (float *)gateWeight, nullptr, nullptr, nullptr, (float *)upWeight, nullptr, nullptr,
                 nullptr, nullptr, nullptr, (float *)downWeight, nullptr, nullptr, false);
         llama_mlp_hub[llama_mlp_key] = llama_mlp;
diff --git a/src/layers/mlp_llama.h b/src/layers/mlp_llama.h
index 8a3bda34..334644bb 100644
--- a/src/layers/mlp_llama.h
+++ b/src/layers/mlp_llama.h
@@ -14,16 +14,12 @@
 // ============================================================================
 #pragma once
 
-#ifdef UNDEBUG
-#undef NDEBUG
-#endif
-
 #include "bert_util.h"
 #include "copy_util.h"
 #include "debugger.h"
 #include "decoder_util.h"
 #include "matmul_helper.h"
-#include "rmsnorm_kernels.h"
+#include "rms_norm.h"
 #include "simple_mem_pool.h"
 #include "singleton.h"
 #include "timeline.h"
@@ -38,12 +34,11 @@
 //     def forward(self, x):
 //         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 // But please also be noted: we extended the MLP to include layer norm
-template
-class LlamaMLP : public SingletonBase> {
+template
+class LlamaMLP {
 public:
-    LlamaMLP() {}
-
-    LlamaMLP(DecoderContext *ctx) {}
+    LlamaMLP(DecoderContext *ctx) : norm(ctx) {}
 
     // OriWeiT: float, int8_t or uint4x2_t
     template
@@ -61,7 +56,6 @@ class LlamaMLP : public SingletonBase> {
         xft::Matrix<WeiT> quantizedGateWeight, quantizedUpWeight, quantizedDownWeight;
 
         auto it = SplitUtil::getTaskRange(imSize, ctx->numSplit, ctx->splitIdx);
-        downWeight.Resize(it.second - it.first, hiddenSize);
 
         ctx->mmHelper->convertWeight(ctx, trans, hiddenSize, imSize, gateW, gateS, gateZ, true, quantizedGateWeight,
                 gateWeightScale, gateWeightZero, gateWeightSum);
@@ -80,15 +74,41 @@ class LlamaMLP : public SingletonBase> {
                     catWeightsSum);
             quantizedGateWeight.Release();
             quantizedUpWeight.Release();
+
+#ifdef XFT_GPU
+            xft::Matrix<WeiT> catWeightsT;
+            int catWeiRows = quantizedCatWeights.Rows();
+            int catWeiCols = quantizedCatWeights.Cols();
+            catWeightsT.Resize(catWeiRows, catWeiCols);
+            ctx->mmHelper->transposeWeight(trans, quantizedCatWeights, catWeightsT);
+
+            WeiT *catWeiData = (WeiT *)xft::alloc(catWeiRows * catWeiCols * sizeof(WeiT), ctx->device);
+            catWeights.Assign(catWeiData, catWeiRows, catWeiCols, catWeiCols);
+            xft::memcopy(catWeights.Data(), catWeightsT.Data(), catWeiRows * catWeiCols * sizeof(WeiT), ctx->device);
+#else
             catWeights.Resize(quantizedCatWeights.Rows(), quantizedCatWeights.Cols());
             ctx->mmHelper->packWeight(trans, quantizedCatWeights, catWeights);
+#endif
         }
         // Horizontally split the down weight
         ctx->mmHelper->convertWeight(ctx, trans, imSize, hiddenSize, downW, downS, downZ, false, quantizedDownWeight,
                 downWeightScale, downWeightZero, downWeightSum);
+#ifdef XFT_GPU
+        xft::Matrix<WeiT> downWeightT;
+        int downWeiRows = it.second - it.first;
+        int downWeiCols = hiddenSize;
+        downWeightT.Resize(downWeiRows, downWeiCols);
+        ctx->mmHelper->transposeWeight(trans, quantizedDownWeight, downWeightT);
+
+        WeiT *downWeiData = (WeiT *)xft::alloc(downWeiRows * downWeiCols * sizeof(WeiT), ctx->device);
+        downWeight.Assign(downWeiData, downWeiRows, downWeiCols, downWeiCols);
+        xft::memcopy(downWeight.Data(), downWeightT.Data(), downWeiRows * downWeiCols * sizeof(WeiT), ctx->device);
+#else
+        downWeight.Resize(it.second - it.first, hiddenSize);
         ctx->mmHelper->packWeight(trans, quantizedDownWeight, downWeight);
+#endif
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("quantizedGateWeight:\n");
         dbg.dumpMatrix(quantizedGateWeight);
 
@@ -100,13 +120,10 @@ class LlamaMLP : public SingletonBase> {
 #endif
 
         // LlamaRMSNorm
-        if (normW) {
-            normWeight.Resize(hiddenSize);
-            memcpy(normWeight.Data(), normW, sizeof(float) * hiddenSize);
-        }
+        if (normW) { norm.setWeight(normW, nullptr, hiddenSize); }
     }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     void setDebugger(const Debugger &debugger) { this->dbg = debugger; }
 #endif
 
@@ -126,11 +143,10 @@ class LlamaMLP : public SingletonBase> {
                 (ImT *)ctx->normBuf.Data(), ctx->normBuf.Rows(), ctx->normBuf.Cols(), ctx->normBuf.Stride());
 
         if (doLnBefore == true) {
-            xft::rmsNorm(normBuffer.Data(), inBuffer.Data(), normWeight.Data(), M, hiddenSize, inBuffer.Stride(),
-                    normBuffer.Stride(), 1e-6);
+            norm.forward(inBuffer.Data(), normBuffer.Data(), M, inBuffer.Stride(), normBuffer.Stride(), 1e-6);
         }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("LayerNorm before MLP:\n");
         dbg.dumpMatrix(normBuffer);
         dbg.debugPrint(">>> residential: [%d, %d] (%d)\n", inBuffer.Rows(), inBuffer.Cols(), inBuffer.Stride());
@@ -142,7 +158,7 @@ class LlamaMLP : public SingletonBase> {
                 (ImT *)ctx->imOut.Data(), ctx->imOut.Rows(), ctx->imOut.Cols(), ctx->imOut.Stride());
 
             gateProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer);
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint(
                     ">>> gateWeight: [%d, %d] (%d)\n", gateWeight.Rows(), gateWeight.Cols(), gateWeight.Stride());
             dbg.dumpMatrix(gateWeight);
@@ -152,7 +168,7 @@ class LlamaMLP : public SingletonBase> {
 
             upProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer);
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint(">>> upWeight: [%d, %d] (%d)\n", upWeight.Rows(), upWeight.Cols(), upWeight.Stride());
             dbg.dumpMatrix(upWeight);
             dbg.debugPrint(">>> up output: [%d, %d] (%d)\n", imBuffer.Rows(), imBuffer.Cols(), imBuffer.Stride());
@@ -168,9 +184,9 @@ class LlamaMLP : public SingletonBase> {
             // Need to allocate extra buffer as oneDNN does not support the case of stride > cols
             const int cols = N / 2;
             auto bufSize = sizeof(ImT) * M * cols;
-            ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize);
+            ImT *t = (ImT *)SimpleMemPool::instance().getBuffer("mlp_silu", bufSize, ctx->device);
             xft::Matrix<ImT> siluBuf(t, M, cols, cols);
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint(
                     ">>> enableCATMLP imBuffer: [%d, %d] (%d)\n", imBuffer.Rows(), imBuffer.Cols(), imBuffer.Stride());
             dbg.dumpMatrix(imBuffer);
@@ -178,7 +194,7 @@ class LlamaMLP : public SingletonBase> {
             dbg.dumpMatrix(inBuffer);
 #endif
             catGateUpProj(ctx, doLnBefore ? normBuffer : inBuffer, imBuffer, siluBuf);
-#ifdef DEBUG
+#ifdef XFT_DEBUG
             dbg.debugPrint("catWeights:\n");
             dbg.dumpMatrix(catWeights);
             dbg.debugPrint("gateUp output:\n");
@@ -189,7 +205,7 @@ class LlamaMLP : public SingletonBase> {
             downProj(ctx, siluBuf, outBuffer, inBuffer, ctx->splitIdx == 0);
         }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint(">>> downWeight: [%d, %d] (%d)\n", downWeight.Rows(), downWeight.Cols(), downWeight.Stride());
         dbg.dumpMatrix(downWeight);
         dbg.debugPrint(">>> residential: [%d, %d] (%d)\n", inBuffer.Rows(), inBuffer.Cols(), inBuffer.Stride());
@@ -298,11 +314,11 @@ class LlamaMLP : public SingletonBase> {
 
         // Compute silu on the left half and then add it with the right half
         if (ctx->actType == DecoderContext::SILU) {
-            DecoderUtil::siluSum(output, siluBuf);
+            DecoderUtil::siluSum(output, siluBuf, ctx->device);
         } else if (ctx->actType == DecoderContext::SWIGLU) { // chatglm2/3
-            DecoderUtil::siluSum(output, siluBuf);
+            DecoderUtil::siluSum(output, siluBuf, ctx->device);
        } else if (ctx->actType == DecoderContext::GELU) { // gemma
-            DecoderUtil::geluSum(output, siluBuf);
+            DecoderUtil::geluSum(output, siluBuf, ctx->device);
         } else {
             printf("ERROR: unsupported activation in MLP.\n");
             exit(-1);
@@ -364,9 +380,9 @@ class LlamaMLP : public SingletonBase> {
     xft::Vector<float> downWeightSum; // For int8_t weight
 
     // LlamaRMSNorm param
-    xft::Vector<float> normWeight;
+    NORM_CLS norm;
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     Debugger dbg;
 #endif
 };
diff --git a/src/layers/mlp_standard.h b/src/layers/mlp_standard.h
index 0e9c6123..a4d65170 100644
--- a/src/layers/mlp_standard.h
+++ b/src/layers/mlp_standard.h
@@ -71,7 +71,7 @@ class MLP {
         }
     }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     void setDebugger(const Debugger &debugger) { this->dbg = debugger; }
 #endif
 
@@ -99,7 +99,7 @@ class MLP {
         auto &imInput = doLnBefore ? (INPUT_AS_RESID ? resultBuffer1 : resultBuffer2) : resultBuffer2;
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("layer norm after attention:\n");
         dbg.dumpMatrix(imInput);
 #endif
@@ -110,7 +110,7 @@ class MLP {
             case DecoderContext::GELU: intermediate_gelu(ctx, imInput, imBuffer); break;
         }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("intermediate:\n");
         dbg.dumpMatrix(imBuffer);
 #endif
@@ -149,7 +149,7 @@ class MLP {
             }
         }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("output:\n");
         dbg.dumpMatrix(resultBuffer1);
 #endif
@@ -157,7 +157,7 @@ class MLP {
         // layerNorm
         if (!doLnBefore) { DecoderUtil::layerNorm(resultBuffer1, resultBuffer1, gamma2, beta2); }
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
         dbg.debugPrint("final output:\n");
         dbg.dumpMatrix(resultBuffer1);
 #endif
@@ -239,7 +239,7 @@ class MLP {
     // layerNorm param
     xft::Vector<float> gamma2, beta2;
 
-#ifdef DEBUG
+#ifdef XFT_DEBUG
     Debugger dbg;
 #endif
 };
diff --git a/src/layers/rms_norm.cpp b/src/layers/rms_norm.cpp
index 0ca054ff..0fb0fe01 100644
--- a/src/layers/rms_norm.cpp
+++ b/src/layers/rms_norm.cpp
@@ -21,53 +21,180 @@
 #include "rms_norm.h"
 #include "rmsnorm_kernels.h"
 #include "timeline.h"
+#include "transformer_ctx.h"
+
+#ifdef XFT_GPU
+#include "gpudnn/gpu_layernorm_kernels.h"
+#include
+#endif
 
 namespace xft {
 
-RmsNorm::RmsNorm() {
+template <typename T>
+RmsNormImp<T>::RmsNormImp() {
+    weight = nullptr;
+    normSize = 0;
+}
+
+template <typename T>
+RmsNormImp<T>::RmsNormImp(DecoderContext *ctx) {
+    device = ctx->device;
     weight = nullptr;
     normSize = 0;
 }
 
-RmsNorm::~RmsNorm() {
-    if (weight) { free(weight); }
+template <typename T>
+RmsNormImp<T>::~RmsNormImp() {
+    if (weight) { xft::dealloc(weight, device); }
 }
 
-void RmsNorm::setWeight(const float *w, const float *, int cols) {
+template <typename T>
+void RmsNormImp<T>::setWeight(const float *w, const float *, int cols) {
+    T weightBuf[cols];
+    if constexpr (std::is_same_v<T, float>) {
+        xft::memcopy(weightBuf, w, cols * sizeof(float));
+    } else if constexpr (std::is_same_v<T, float16_t>) {
+        float16_t::cvt_float_to_float16(w, weightBuf, cols);
+    } else if constexpr (std::is_same_v<T, bfloat16_t>) {
+        bfloat16_t::cvt_float_to_bfloat16(w, weightBuf, cols);
+    } else {
+        printf("%s:%d: Could not setWeight in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
+
     this->normSize = cols;
-    this->weight = (float *)xft::alloc(cols * sizeof(float));
-    memcpy(weight, w, cols * sizeof(float));
+    this->weight = (T *)xft::alloc(cols * sizeof(T), device);
+    xft::memcopy(this->weight, weightBuf, cols * sizeof(T), device);
 }
 
-void RmsNorm::setWeight(const std::string &modelPath, const std::string &, int cols) {
+template <typename T>
+void RmsNormImp<T>::setWeight(const std::string &modelPath, const std::string &, int cols) {
+    float weightBuf[cols];
+    float *weiBuf = &weightBuf[0];
+    loadWeight(modelPath, weiBuf, cols);
     this->normSize = cols;
-    loadWeight(modelPath, weight, cols);
+    this->setWeight(weightBuf, nullptr, cols);
+}
+
+#ifdef XFT_GPU
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("RmsNorm.forward");
+    sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+    if constexpr (std::is_same_v<T, float>) {
+        fastertransformer::invokeGeneralT5LayerNorm(
+                output, input, weight, (const float *)nullptr, epsilon, rows, normSize, gpu_queue);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
+}
+
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("RmsNorm.forward");
+    printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+    exit(-1);
+}
+
+template <typename T>
+void RmsNormImp<T>::forward(
+        const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("RmsNorm.forward");
+    sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+    if constexpr (std::is_same_v<T, bfloat16_t>) {
+        // TODO: Add BF16 RmsNorm Implementation.
+        // fastertransformer::invokeGeneralT5LayerNorm(
+        //         output, input, weight, (const bfloat16_t *)nullptr, epsilon, rows, normSize, gpu_queue);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
+}
+
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("RmsNorm.forward");
+    printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+    exit(-1);
+}
+
+template <typename T>
+void RmsNormImp<T>::forward(
+        const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
+    TimeLine t("RmsNorm.forward");
+    sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+    if constexpr (std::is_same_v<T, float16_t>) {
+        fastertransformer::invokeGeneralT5LayerNorm((sycl::half *)output, (const sycl::half *)input,
+                (const sycl::half *)weight, (const sycl::half *)nullptr, epsilon, rows, normSize, gpu_queue);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
+}
+
+#else
 // input and output are in shape of (rows, normSize)
-void RmsNorm::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, float *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
-    rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    if constexpr (std::is_same_v<T, float>) {
+        rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
 }
 
-void RmsNorm::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
-    rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    if constexpr (std::is_same_v<T, float>) {
+        rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
 }
 
-void RmsNorm::forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
+template <typename T>
+void RmsNormImp<T>::forward(
+        const bfloat16_t *input, bfloat16_t *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
-    rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    if constexpr (std::is_same_v<T, float>) {
+        rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
 }
 
-void RmsNorm::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
+template <typename T>
+void RmsNormImp<T>::forward(const float *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
-    rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    if constexpr (std::is_same_v<T, float>) {
+        rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
 }
 
-void RmsNorm::forward(const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
+template <typename T>
+void RmsNormImp<T>::forward(
+        const float16_t *input, float16_t *output, int rows, int iStride, int oStride, float epsilon) {
     TimeLine t("RmsNorm.forward");
-    rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    if constexpr (std::is_same_v<T, float>) {
+        rmsNorm(output, input, weight, rows, normSize, iStride, oStride, epsilon);
+    } else {
+        printf("%s:%d: Could not forward in RmsNorm with undefined data type.\n", __FILE__, __LINE__);
+        exit(-1);
+    }
 }
+#endif
+
+template class RmsNormImp<float>;
+template class RmsNormImp<float16_t>;
+template class RmsNormImp<bfloat16_t>;
 } // namespace xft
\ No newline at end of file
diff --git a/src/layers/rms_norm.h b/src/layers/rms_norm.h
index 31dc0b06..b78a3ae0 100644
--- a/src/layers/rms_norm.h
+++ b/src/layers/rms_norm.h
@@ -15,15 +15,18 @@
 #pragma once
 #include "bfloat16.h"
+#include "transformer_ctx.h"
 #include "weight_util.h"
 
 namespace xft {
 
 // RMS normalization: only support the norm along last dimension
-class RmsNorm {
+template <typename T>
+class RmsNormImp {
 public:
-    RmsNorm();
-    ~RmsNorm();
+    RmsNormImp();
+    RmsNormImp(DecoderContext *ctx);
+    ~RmsNormImp();
 
     void setWeight(const float *w, const float *, int cols);
     void setWeight(const std::string &modelPath, const std::string &, int cols);
@@ -39,8 +42,8 @@ class RmsNorm {
     void forward(const bfloat16_t *input, bfloat16_t *output, int rows, int iStride = -1, int oStride = -1,
             float epsilon = 1e-6);
 
-    void forward(const float *input, float16_t *output, int rows, int iStride = -1, int oStride = -1,
-            float epsilon = 1e-6);
+    void forward(
+            const float *input, float16_t *output, int rows, int iStride = -1, int oStride = -1, float epsilon = 1e-6);
 
     void forward(const float16_t *input, float16_t *output, int rows, int iStride = -1, int oStride = -1,
             float epsilon = 1e-6);
@@ -49,7 +52,14 @@ class RmsNorm {
     int normSize;
 
     // the scale weight
-    float *weight = nullptr;
+    T *weight = nullptr;
+    void *device = nullptr;
 };
 
+#ifdef XFT_GPU
+using RmsNorm = RmsNormImp<float16_t>;
+#else
+using RmsNorm = RmsNormImp<float>;
+#endif
+
 } // namespace xft
\ No newline at end of file
(!ctx->cached(inv_freq_str + "_gpu")) { + inv_freq = ctx->getBuffer(inv_freq_str + "_gpu", inv_freq_size); + xft::memcopy(emb_cos, emb_cos_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); + xft::memcopy(emb_sin, emb_sin_bak, max_position_embeddings * inv_freq_size * sizeof(float), device); + } + } +#endif } // This API is deprecated, will delete after all rotary embed code refactor. @@ -90,6 +105,28 @@ LlamaRotaryEmbedding::LlamaRotaryEmbedding(const int dim, const int max_position // |_____| |_____| // head_size/2 head_size/2 +#ifdef XFT_GPU + +void LlamaRotaryEmbedding::forward( + float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void LlamaRotaryEmbedding::forward( + bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +void LlamaRotaryEmbedding::forward( + float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { + xft::llamaApplyRotaryPosEmbeding(this->device, + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); +} + +#else + void LlamaRotaryEmbedding::forward( float *query, float *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { int dim = inv_freq_size * 2; @@ -138,14 +175,18 @@ void LlamaRotaryEmbedding::forward( void LlamaRotaryEmbedding::forward( bfloat16_t *query, bfloat16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - xft::llamaApplyRotaryPosEmbeding(query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + xft::llamaApplyRotaryPosEmbeding( + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } void LlamaRotaryEmbedding::forward( float16_t *query, float16_t *key, int qStride, int kStride, const int *qkShape, const int *positionIds) { - xft::llamaApplyRotaryPosEmbeding(query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); + xft::llamaApplyRotaryPosEmbeding( + query, key, qStride, kStride, emb_cos, emb_sin, inv_freq_size, qkShape, positionIds); } +#endif // GPU + // For continuous batching void LlamaRotaryEmbedding::forward( float *query, float *key, int totSeqLen, int qStride, int kStride, int qHeads, int kHeads, int *positionIds) { diff --git a/src/layers/rotary_embedding.h b/src/layers/rotary_embedding.h index 1b96d033..9142d129 100644 --- a/src/layers/rotary_embedding.h +++ b/src/layers/rotary_embedding.h @@ -66,4 +66,5 @@ class LlamaRotaryEmbedding { float *inv_freq = nullptr; float *emb_cos = nullptr; float *emb_sin = nullptr; + void *device = nullptr; }; diff --git a/src/layers/token_embedding.h b/src/layers/token_embedding.h index f49a135e..c49dd597 100644 --- a/src/layers/token_embedding.h +++ b/src/layers/token_embedding.h @@ -24,6 +24,7 @@ class TokenEmbedding { TokenEmbedding(DecoderContext *ctx) { this->vocabSize = ctx->vocabSize; this->hiddenSize = ctx->hiddenSize; + this->device = ctx->device; } void setWeights(float *tokenEmb) { @@ -59,4 +60,5 @@ class TokenEmbedding { int hiddenSize; T *embTable = nullptr; + void *device = nullptr; }; diff --git a/src/models/common_decoder.h b/src/models/common_decoder.h index 
8ec4fe22..d463a899 100644 --- a/src/models/common_decoder.h +++ b/src/models/common_decoder.h @@ -155,7 +155,7 @@ class CommonDecoder : public AbstractDecoder { public: CommonDecoder(const std::string &modelPath, const std::string &modelType) : messenger(Messenger::getInstance()) -#ifdef DEBUG +#ifdef XFT_DEBUG , dbg("model_decoder.csv") #endif { @@ -256,7 +256,7 @@ class CommonDecoder : public AbstractDecoder { virtual ~CommonDecoder() { if (this->inputTokens) free(this->inputTokens); - if (this->attnMask) free(this->attnMask); + if (this->attnMask) xft::dealloc(this->attnMask); delete this->decoderBlock; delete this->predictor; @@ -313,7 +313,7 @@ class CommonDecoder : public AbstractDecoder { this->embeddingForward(ids, embBuf, batchSize * inputSeqLen); this->accSeqLen += seqLen; -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("---- embedding.forward ----\n"); dbg.debugPrint("ids:\n"); dbg.dumpMatrix(ids, batchSize, inputSeqLen, inputSeqLen); @@ -342,21 +342,22 @@ class CommonDecoder : public AbstractDecoder { // TODO: Error: different scope when dynamic loading so file // this->messenger.worldRecvFP32(embBuf, count, prev_world_rank, curr_world_rank); if (!SequencePool::getInstance().has(sequenceID)) { - auto *seqs = SequencePool::getInstance().newMeta(sequenceID, seqLen); - seqs->get(0)->setPastSeqLen(pastSeqLen); - seqs->get(0)->allocBuffer(hiddenSize, embBuf); - SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs); + auto *groupMeta = SequencePool::getInstance().newGroupMeta(sequenceID, seqLen); + groupMeta->get(0)->setPastSeqLen(pastSeqLen); + groupMeta->get(0)->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(groupMeta); } TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(sequenceID)); } if (!InputQueue::getInstance().empty()) { if (!TaskWaitingQueue::getInstance().isFull()) { - auto *seqs = InputQueue::getInstance().pop(); - seqs->get(0)->setPastSeqLen(pastSeqLen); - seqs->get(0)->allocBuffer(hiddenSize, embBuf); - SequencePool::getInstance().add(seqs->get(0)->getSequenceID(), seqs); - TaskWaitingQueue::getInstance().push(SequencePool::getInstance().get(seqs->get(0)->getSequenceID())); + auto *groupMeta = InputQueue::getInstance().pop(); + groupMeta->get(0)->setPastSeqLen(pastSeqLen); + groupMeta->get(0)->allocBuffer(hiddenSize, embBuf); + SequencePool::getInstance().add(groupMeta); + TaskWaitingQueue::getInstance().push( + SequencePool::getInstance().get(groupMeta->get(0)->getSequenceID())); } } @@ -370,6 +371,16 @@ class CommonDecoder : public AbstractDecoder { TimeLine t("Decoder.Seq" + std::to_string(sequenceID) + ".Step"); #endif +#ifdef XFT_GPU + size_t embBufSize = batchSize * inputSeqLen * hiddenSize * sizeof(AttnInT); + AttnInT *embBufTmp = (AttnInT *)xft::alloc(embBufSize, ctx->device); + AttnInT *outBufTmp = (AttnInT *)xft::alloc( + actBuffers->Rows() * actBuffers->Cols() * sizeof(float) - embBufSize, ctx->device); + xft::memcopy(embBufTmp, embBuf, embBufSize, ctx->device); + embBuf = embBufTmp; + outBuf = outBufTmp; +#endif + // Decoder: forward int layers_per_pp_stage = decoderBlock->size(); for (int i = 0; i < layers_per_pp_stage; ++i) { @@ -446,20 +457,21 @@ class CommonDecoder : public AbstractDecoder { lnIn = outBuf; #pragma omp parallel for for (int b = 0; b < batchSize; ++b) { - memcpy(lnIn + b * hiddenSize, embBuf + ((b + 1) * inputSeqLen - 1) * hiddenSize, - hiddenSize * sizeof(MlpOutT)); + xft::memcopy(lnIn + b * hiddenSize, embBuf + ((b + 1) * inputSeqLen - 1) * hiddenSize, + hiddenSize * sizeof(MlpOutT), 
ctx->device ? ctx->device : nullptr); } } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> DecoderLayer Output[%d, %d] (%d):\n", batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.dumpMatrix(embBuf, batchSize * inputSeqLen, hiddenSize, hiddenSize); dbg.debugPrint("LayerNorm In:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(lnIn, batchSize, hiddenSize, hiddenSize); - else + } else { dbg.dumpMatrix(lnIn, batchSize * inputSeqLen, hiddenSize, hiddenSize); + } #endif // LN, as it supports inplace computing, input and output can be the same @@ -469,33 +481,44 @@ class CommonDecoder : public AbstractDecoder { else lastLayerNormForward(lnIn, lnOut, batchSize * seqLen); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm Out:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(lnOut, batchSize, hiddenSize, hiddenSize); - else + } else { dbg.dumpMatrix(lnOut, batchSize * inputSeqLen, hiddenSize, hiddenSize); + } #endif // Predictor + const int splitSize = this->predictor->getSplitSize(); float *finalOut = (float *)outBuf; if (!logitsAll) this->predictor->forward(ctx, lnOut, finalOut, batchSize); else this->predictor->forward(ctx, lnOut, finalOut, batchSize * seqLen); -#ifdef DEBUG - auto splitSize = this->predictor->getSplitSize(); +#ifdef XFT_DEBUG dbg.debugPrint("finalOut:\n"); - if (!logitsAll) + if (!logitsAll) { dbg.dumpMatrix(finalOut, batchSize, splitSize, splitSize); - else + } else { dbg.dumpMatrix(finalOut, batchSize * inputSeqLen, splitSize, splitSize); + } +#endif + +#ifdef XFT_GPU + xft::dealloc(embBuf, ctx->device); + embBuf = (AttnInT *)actBuffers->Data(); + + float *finalOutTmp = (float *)(embBuf + batchSize * inputSeqLen * hiddenSize); + xft::memcopy(finalOutTmp, finalOut, batchSize * splitSize * sizeof(float), ctx->device); + xft::dealloc(outBuf, ctx->device); + finalOut = finalOutTmp; #endif // Expand the result to make it cover multiple beams if (step == 0 && beamSize > 1) { - const int splitSize = this->predictor->getSplitSize(); for (int b = userSideBS - 1; b >= 0; --b) { float *src = finalOut + b * splitSize; #pragma omp parallel for @@ -515,7 +538,7 @@ class CommonDecoder : public AbstractDecoder { } std::tuple forward(std::vector &seqs, bool logitsAll = false) { - // Assume all sequences are all prompts(step==0) or all decodes(step>0) + // Assume all sequences are all prompts(step==0) or all decodes(step>0) // Assume input has been synced with master in higher level. 
TimeLine t("Decoder.forward"); TimeLine t1("Decoder.embedding"); @@ -562,7 +585,7 @@ class CommonDecoder : public AbstractDecoder { } } -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint(">>> DecoderLayer Output[%d, %d] (%d):\n", logitRows, hiddenSize, hiddenSize); dbg.dumpMatrix(embBuf, logitRows, hiddenSize, hiddenSize); dbg.debugPrint("LayerNorm In:\n"); @@ -574,7 +597,7 @@ class CommonDecoder : public AbstractDecoder { MlpOutT *lnOut = embBuf; lastLayerNormForward(lnIn, lnOut, logitRows); -#ifdef DEBUG +#ifdef XFT_DEBUG dbg.debugPrint("LayerNorm Out:\n"); dbg.dumpMatrix(lnOut, logitRows, hiddenSize, hiddenSize); #endif @@ -583,7 +606,7 @@ class CommonDecoder : public AbstractDecoder { float *finalOut = (float *)outBuf; this->predictor->forward(ctx, lnOut, finalOut, logitRows); -#ifdef DEBUG +#ifdef XFT_DEBUG auto splitSize = this->predictor->getSplitSize(); dbg.debugPrint("finalOut:\n"); dbg.dumpMatrix(finalOut, logitRows, splitSize, splitSize); @@ -739,14 +762,20 @@ class CommonDecoder : public AbstractDecoder { exit(-1); } } else { - this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, - epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, ppSize, - ppRank, ropeParamsPtr, useLogN, useNTK)); - + int engineIdx = 0; if (env.getEngineKind() == xft::DeviceKind::iGPU && env.getEngineIndex() < 0) // Sequential assignment - this->context->mmHelper = new MMHelper(env.getEngineKind(), ppRank * tpSize + tpRank); + engineIdx = ppRank * tpSize + tpRank; else // assignment through the user - this->context->mmHelper = new MMHelper(env.getEngineKind(), env.getEngineIndex()); + engineIdx = env.getEngineIndex(); + + this->mmHelper.reset(new MMHelper(env.getEngineKind(), engineIdx)); +#ifdef XFT_GPU + auto devices = sycl::device::get_devices(sycl::info::device_type::gpu); + this->device.reset(new sycl::queue(devices[this->mmHelper->getEngineCount() + engineIdx])); +#endif + this->context.reset(new DecoderContext(layers, hiddenSize, headSize, attHeadNum, kvHeadNum, imSize, act, + epsilon, vocabSize, embeddingSize, maxPositions, maxPosEmbed, maxSeqLength, tpRank, tpSize, + this->mmHelper.get(), this->device.get(), ppSize, ppRank, ropeParamsPtr, useLogN, useNTK)); } return this->context.get(); @@ -765,7 +794,7 @@ class CommonDecoder : public AbstractDecoder { int kvSize = attHeadSize * kvHeadNum; int qkvSize = qSize + 2 * kvSize; -#define ALLOC(size, alignment) xft::alloc((size), (alignment)) +#define ALLOC(size, alignment) xft::alloc((size), nullptr, (alignment)) OriWeiT *qkvWeight = (OriWeiT *)ALLOC(hiddenSize * qkvSize * sizeof(OriWeiT), 64); float *qkvScales = nullptr; float *qkvZeros = nullptr; @@ -1067,6 +1096,8 @@ class CommonDecoder : public AbstractDecoder { // Execution context std::shared_ptr context; + std::shared_ptr mmHelper; + std::shared_ptr device; // The initial input sequence length, which is the prompt token size int initSeqLen; @@ -1105,7 +1136,7 @@ class CommonDecoder : public AbstractDecoder { int startId; int endId; -#ifdef DEBUG +#ifdef XFT_DEBUG Debugger dbg; #endif }; diff --git a/src/models/llama.cpp b/src/models/llama.cpp index 2b7b6195..70fb155e 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -31,12 +31,14 @@ LlamaLLM::LlamaLLM(const std::string &modelPath) setEmbeddingWeights(modelPath); // Final LN + finalLN = new RmsNorm(ctx); setFinalLnWeight(modelPath); } template LlamaLLM::~LlamaLLM() { delete embedding; + delete finalLN; } template @@ -46,7 +48,7 @@ void 
LlamaLLM::setEmbeddingWeights(const std::string &modelPath) template void LlamaLLM::setFinalLnWeight(const std::string &modelPath) { - finalLN.setWeight(modelPath + "/model.final_layernorm.weight.bin", "", embedding->getHiddenSize()); + finalLN->setWeight(modelPath + "/model.final_layernorm.weight.bin", "", embedding->getHiddenSize()); } // Prepare attention_mask which is like: @@ -121,17 +123,17 @@ void LlamaLLM::embeddingForward(int *ids, float16_t *output, int template void LlamaLLM::lastLayerNormForward(float *input, float *output, int rows) { - finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } template void LlamaLLM::lastLayerNormForward(bfloat16_t *input, bfloat16_t *output, int rows) { - finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } template void LlamaLLM::lastLayerNormForward(float16_t *input, float16_t *output, int rows) { - finalLN.forward(input, output, rows); + finalLN->forward(input, output, rows); } IMPLEMENT_MODEL(LlamaLLM, llama) \ No newline at end of file diff --git a/src/models/llama.h b/src/models/llama.h index 5eea8a06..9bcea5aa 100644 --- a/src/models/llama.h +++ b/src/models/llama.h @@ -48,7 +48,7 @@ class LlamaLLM private: TokenEmbedding *embedding; - RmsNorm finalLN; + RmsNorm *finalLN; }; REGISTER_MODEL(LlamaLLM, llama) \ No newline at end of file diff --git a/src/utils/compile_util.h b/src/utils/compile_util.h index e11cf32c..e77e8f6b 100644 --- a/src/utils/compile_util.h +++ b/src/utils/compile_util.h @@ -17,6 +17,10 @@ #include #include +#ifdef XFT_GPU +#include +#endif + #define likely(x) __builtin_expect((x), 1) #define unlikely(x) __builtin_expect((x), 0) diff --git a/src/utils/debugger.h b/src/utils/debugger.h index ca74d115..ce130393 100644 --- a/src/utils/debugger.h +++ b/src/utils/debugger.h @@ -114,6 +114,15 @@ class Debugger { } } +#ifdef XFT_GPU + template + void dumpMatrix(xft::Matrix &m, bool print_all = false) { + } + + template + void dumpMatrix(T *data, uint64_t rows, uint64_t cols, uint64_t stride, bool print_all = false) { + } +#else template void dumpMatrix(xft::Matrix &m, bool print_all = false) { std::ostringstream oss; @@ -281,7 +290,7 @@ class Debugger { fflush(debugFile); } } - +#endif // Function to store float* data to a file template void storeMatrix(const std::string &filename, const T *data, uint64_t rows, uint64_t cols) { diff --git a/src/utils/decoder_util.h b/src/utils/decoder_util.h index 08289a16..4c8dba1c 100644 --- a/src/utils/decoder_util.h +++ b/src/utils/decoder_util.h @@ -28,6 +28,10 @@ #include "transformer_ctx.h" #include "xdnn.h" +#ifdef XFT_GPU +#include +#endif + extern int getFlashThresh(); extern bool enableCATMLP(); extern bool enableSkipMsk(); @@ -444,9 +448,41 @@ class DecoderUtil { return std::make_pair(maxVal, sum); } +#ifdef XFT_GPU + template + static void siluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { + int M = src.Rows(); + int lds = src.Stride(); + int N = lds / 2; + int ldd = dst.Stride(); + + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + + if constexpr (std::is_same_v && std::is_same_v) { + sycl::half *src0 = (sycl::half *)src.Data(); + sycl::half *src1 = (sycl::half *)(src.Data() + N); + sycl::half *dest = (sycl::half *)dst.Data(); + + gpu_queue + ->submit([&](sycl::handler &h) { + h.parallel_for(M * N, [=](auto i) { + int32_t row = i / N; + int32_t col = i % N; + sycl::half tmp0 = src0[row * lds + col]; + sycl::half tmp1 = src1[row * lds + col]; + dest[row * ldd + col] = tmp0 * tmp1 + / 
((sycl::half)1.0f + (sycl::half)sycl::native::exp(tmp0 * -1.0f)); + }); + }) + .wait(); + } + } + } +#else // compute silu on the left half and then add it with the right half template - static void siluSum(xft::Matrix &src, xft::Matrix &dst) { + static void siluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { __m512 one = _mm512_set1_ps(1.f); __m512 negOne = _mm512_set1_ps(-1.f); int M = src.Rows(); @@ -469,10 +505,11 @@ class DecoderUtil { } } } +#endif // compute gelu on the left half and then add it with the right half template - static void geluSum(xft::Matrix &src, xft::Matrix &dst) { + static void geluSum(xft::Matrix &src, xft::Matrix &dst, void *device = nullptr) { const __m512 c1 = _mm512_set1_ps(0.044715f); const __m512 c2 = _mm512_set1_ps(0.7978845608f); const __m512 vone = _mm512_set1_ps(1.0f); diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h index 988f9e43..536160a0 100644 --- a/src/utils/matmul_helper.h +++ b/src/utils/matmul_helper.h @@ -20,6 +20,7 @@ #include "dtype.h" #include "environment.h" #include "float16.h" +#include "intrinsics_util.h" #include "my_types.h" #include "normal_float4x2.h" #include "oneapi/dnnl/dnnl.hpp" @@ -53,11 +54,21 @@ class MMHelper { } AMXThresholdM = Env::getInstance().getAMXThresholdM(); + cpu_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0); + cpu_stream = new dnnl::stream(*cpu_engine); } ~MMHelper() { if (engine) delete engine; if (stream) delete stream; + + for (auto &pair : matmul_hub) { + dnnl::matmul::primitive_desc *primitive_desc_ptr = std::get<0>(pair.second); + dnnl::matmul *matmul_ptr = std::get<1>(pair.second); + + delete primitive_desc_ptr; + delete matmul_ptr; + } } // Pack the MatMul weight from 'src(rows, cols)' to 'weight' @@ -215,8 +226,8 @@ class MMHelper { int offset = trans ? rowOffset : colOffset; scaleWeight.Resize(size); zeroWeight.Resize(size); - memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); + if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride(); @@ -231,8 +242,8 @@ class MMHelper { int offset = trans ? rowOffset : colOffset; scaleWeight.Resize(size); zeroWeight.Resize(size); - memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); - memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); + if (scales) memcpy(scaleWeight.Data(), scales + offset, size * sizeof(float)); + if (zeros) memcpy(zeroWeight.Data(), zeros + offset, size * sizeof(float)); #pragma omp parallel for for (uint64_t i = 0; i < rowSize; i++) { WeiT *dst = convertedWeight.Data() + i * convertedWeight.Stride() / 2; @@ -380,8 +391,9 @@ class MMHelper { // W8A8 else if constexpr (std::is_same_v) { using dt = dnnl::memory::data_type; + auto tag = trans ? 
dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab; - dnnl::memory B_mem({{K, N}, dt::s8, tag}, *this->engine, src.Data()); + dnnl::memory B_mem({{K, N}, dt::s8, tag}, *cpu_engine, src.Data()); dnnl::memory::desc desc({K, N}, dt::s8, get_onednn_weight_layout(dt::s8)); // When converting to oneDNN blocked memory format, padded dims can be larger than [K, N] @@ -391,9 +403,9 @@ class MMHelper { weight.Resize(dims[0], dims[1]); weight.Resize(K, N); - dnnl::memory packedB_mem(desc, *engine, weight.Data()); - dnnl::reorder(B_mem, packedB_mem).execute(*stream, B_mem, packedB_mem); - stream->wait(); + dnnl::memory packedB_mem(desc, *cpu_engine, weight.Data()); + dnnl::reorder(B_mem, packedB_mem).execute(*cpu_stream, B_mem, packedB_mem); + cpu_stream->wait(); } // INT4 @@ -427,6 +439,34 @@ class MMHelper { } } + template + void transposeWeight(bool trans, xft::Matrix &src, xft::Matrix &dst) { + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + + dt weight_dt; + if constexpr (std::is_same_v) { + weight_dt = dt::f32; + } else if constexpr (std::is_same_v) { + weight_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + weight_dt = dt::f16; + } else { + printf(">>> transposeWeight: input data type not supported."); + exit(-1); + } + + int K = trans ? src.Cols() : src.Rows(); + int N = trans ? src.Rows() : src.Cols(); + auto weight_md = memory::desc({K, N}, weight_dt, trans ? tag::ba : tag::ab); + auto weight_mem = memory(weight_md, *cpu_engine, src.Data()); + auto transposed_weight_md = memory::desc({K, N}, weight_dt, get_onednn_weight_layout(weight_dt)); + auto transposed_weight_mem = memory(transposed_weight_md, *cpu_engine, dst.Data()); + dnnl::reorder(weight_mem, transposed_weight_mem).execute(*cpu_stream, weight_mem, transposed_weight_mem); + cpu_stream->wait(); + } + template void compute(bool transA, int M, int N, int K, float alpha, const InT *A, int lda, const WeiT *packedB, const float *scaleB, const float *zeroB, const float *sumB, float beta, OutT *C, int ldc) { @@ -439,9 +479,14 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute", - xdnn_sgemm_f32f16f32_compute( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute", + xdnn_sgemm_f32f16f32_compute( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute", @@ -469,7 +514,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute", xdnn_sgemm_f32bf16f32_compute( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? 
if constexpr (std::is_same_v) { @@ -558,9 +603,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd", - xdnn_sgemm_f32f16f32_compute_biasadd( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd", + xdnn_sgemm_f32f16f32_compute_biasadd( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + } else { + GEMMVERBOSE("onednn_gemm_compute_bias", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, + (const InT *)nullptr, -1, matmul_kinds::BiasAdd)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd", @@ -588,7 +639,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_biasadd", xdnn_sgemm_f32bf16f32_compute_biasadd( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? if constexpr (std::is_same_v) { @@ -677,12 +728,19 @@ class MMHelper { GEMMVERBOSE("xdnn_sgemm_compute_biasadd_relu", xdnn_sgemm_compute_biasadd_relu(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias)); } + // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd_relu", - xdnn_sgemm_f32f16f32_compute_biasadd_relu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_biasadd_relu", + xdnn_sgemm_f32f16f32_compute_biasadd_relu( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, bias)); + } else { + GEMMVERBOSE("onednn_gemm_compute_bias_relu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, + (const InT *)nullptr, -1, matmul_kinds::BiasAdd_Relu)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_biasadd_relu", @@ -710,7 +768,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_biasadd_relu", xdnn_sgemm_f32bf16f32_compute_biasadd_relu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if (M > AMXThresholdM) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu", @@ -795,9 +853,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_silu", - xdnn_sgemm_f32f16f32_compute_silu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_silu", + xdnn_sgemm_f32f16f32_compute_silu( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute_silu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, (const InT *)nullptr, -1, matmul_kinds::Silu)); + } #elif 
defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_silu", @@ -825,7 +889,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_silu", xdnn_sgemm_f32bf16f32_compute_silu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_silu", @@ -916,9 +980,16 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_gelu", - xdnn_sgemm_f32f16f32_compute_gelu( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_gelu", + xdnn_sgemm_f32f16f32_compute_gelu( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc)); + } else { + GEMMVERBOSE("onednn_gemm_compute_gelu", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, (const InT *)nullptr, -1, matmul_kinds::Gelu)); + } + #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_gelu", @@ -946,7 +1017,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_gelu", xdnn_sgemm_f32bf16f32_compute_gelu( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc)); + transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_gelu", @@ -1038,9 +1109,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resmul", - xdnn_sgemm_f32f16f32_compute_resmul( - transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resmul", + xdnn_sgemm_f32f16f32_compute_resmul( + transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, beta, C, ldc, res, ldres)); + } else { + GEMMVERBOSE("onednn_gemm_compute_resmul", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, + (const float *)nullptr, res, ldres, matmul_kinds::Resmul)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resmul", @@ -1050,7 +1127,8 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_resmul", xdnn_hgemm_compute_resmul(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, (const XDNN_FP16 *)res, + ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_resmul", xdnn_hgemm_f16f16f32_compute_resmul(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, @@ -1068,7 +1146,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_resmul", xdnn_sgemm_f32bf16f32_compute_resmul( - transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, res, ldres)); + transA, M, N, K, alpha, A, 
lda, (const XDNN_BF16 *)packedB, beta, C, ldc, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_resmul", @@ -1161,9 +1239,15 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_residential", - xdnn_sgemm_f32f16f32_compute_residential(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, - beta, C, ldc, bias, res, ldres)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_residential", + xdnn_sgemm_f32f16f32_compute_residential(transA, M, N, K, alpha, A, lda, + (const XDNN_FP16 *)packedB, beta, C, ldc, bias, res, ldres)); + } else { + GEMMVERBOSE("onednn_gemm_compute_residential", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_residential", @@ -1173,7 +1257,8 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_residential", xdnn_hgemm_compute_residential(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, (const XDNN_FP16 *)res, + ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_residential", xdnn_hgemm_f16f16f32_compute_residential(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, @@ -1191,7 +1276,7 @@ class MMHelper { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_residential", xdnn_sgemm_f32bf16f32_compute_residential(transA, M, N, K, alpha, A, lda, - (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias, res, ldres)); + (const XDNN_BF16 *)packedB, beta, C, ldc, bias, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) // TODO: xdnn impl? if constexpr (std::is_same_v) { @@ -1285,9 +1370,26 @@ class MMHelper { // FP16 else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_FP16 - GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resext", - xdnn_sgemm_f32f16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, - beta, C, ldc, bias, gamma, res, ldres)); + if constexpr (std::is_same_v && std::is_same_v) { + GEMMVERBOSE("xdnn_sgemm_f32f16f32_compute_resext", + xdnn_sgemm_f32f16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_FP16 *)packedB, + beta, C, ldc, bias, gamma, res, ldres)); + } else { +#pragma omp parallel for collapse(2) + for (uint64_t i = 0; i < M; ++i) { + for (int j = 0; j < N; j += 16) { + auto remain = N - j; + __mmask16 mask = (remain >= 16 ? 
0xffff : (1 << remain) - 1); + auto v = xft::load_avx512(mask, &res[i * ldres + j]); + v = _mm512_mul_ps(_mm512_set1_ps(gamma), v); + xft::store_avx512(&res[i * ldres + j], mask, v); + } + } + + GEMMVERBOSE("onednn_gemm_compute_resext", + onednn_gemm_compute(transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias, res, ldres, + matmul_kinds::Residential)); + } #elif defined(AVX512_FP16_WEIGHT_ONLY_FP16) if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f32f16f32_compute_resext", @@ -1297,11 +1399,13 @@ class MMHelper { if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_compute_resext", xdnn_hgemm_compute_resext(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, gamma, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, (XDNN_FP16 *)C, ldc, bias, gamma, + (const XDNN_FP16 *)res, ldres)); } else if constexpr (std::is_same_v) { GEMMVERBOSE("xdnn_hgemm_f16f16f32_compute_resext", xdnn_hgemm_f16f16f32_compute_resext(transA, M, N, K, alpha, (const XDNN_FP16 *)A, lda, - (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, (const XDNN_FP16 *)res, ldres)); + (const XDNN_FP16 *)packedB, beta, C, ldc, bias, gamma, (const XDNN_FP16 *)res, + ldres)); } } #else @@ -1314,7 +1418,7 @@ class MMHelper { else if constexpr (std::is_same_v) { #ifdef AVX512_FP32_WEIGHT_ONLY_BF16 GEMMVERBOSE("xdnn_sgemm_f32bf16f32_compute_resext", - xdnn_sgemm_f32bf16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, + xdnn_sgemm_f32bf16f32_compute_resext(transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias, gamma, res, ldres)); #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16) if constexpr (std::is_same_v) { @@ -1410,11 +1514,18 @@ class MMHelper { } } + int getEngineCount() { + int count = engine->get_count(kind); + return count; + } + private: dnnl::engine::kind kind; - dnnl::engine *engine; - dnnl::stream *stream; + dnnl::engine *engine; // For runtime engine + dnnl::stream *stream; // For runtime stream std::unordered_map> matmul_hub; + dnnl::engine *cpu_engine; + dnnl::stream *cpu_stream; int AMXThresholdM; @@ -1429,18 +1540,19 @@ class MMHelper { Resext, }; - std::string create_key(bool transA, int M, int N, int K, int matmul_kind) { - std::string key = std::to_string(transA) + "_" + std::to_string(M) + "_" + std::to_string(N) + "_" - + std::to_string(K) + "_" + std::to_string(matmul_kind); - return key; + template + std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const Twei *packedB) { + std::stringstream key; + key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind << "_" << packedB; + return key.str(); } dnnl::memory::format_tag get_onednn_input_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - return dnnl::memory::format_tag::undef; + return dnnl::memory::format_tag::ab; } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::AB32a16b; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ab; + // return dnnl::memory::format_tag::AB32a16b; } else { printf("[XFT][ERROR] Need a right engine kind in input layout."); std::exit(-1); @@ -1451,6 +1563,8 @@ class MMHelper { if (this->kind == dnnl::engine::kind::cpu) { if (dt == dnnl::memory::data_type::bf16) { return dnnl::memory::format_tag::BA16a64b2a; + } else if (dt == dnnl::memory::data_type::f16) { + return dnnl::memory::format_tag::BA16a64b; } else if (dt == dnnl::memory::data_type::s8) { return 
dnnl::memory::format_tag::BA16a64b4a; + } else { @@ -1458,26 +1572,263 @@ class MMHelper { std::exit(-1); } } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::BA4b8a8b2a; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ba; + // return dnnl::memory::format_tag::BA4b8a8b2a; } else { printf("[XFT][ERROR] Need a right engine kind in weight layout."); std::exit(-1); } } + dnnl::memory::format_tag get_onednn_bias_layout(dnnl::memory::data_type dt) { + if (this->kind == dnnl::engine::kind::cpu) { + return dnnl::memory::format_tag::ab; + } else if (this->kind == dnnl::engine::kind::gpu) { + return dnnl::memory::format_tag::ab; + } else { + printf("[XFT][ERROR] Need a right engine kind in bias layout."); + std::exit(-1); + } + } + + dnnl::memory::format_tag get_onednn_shift_layout(dnnl::memory::data_type dt) { + if (this->kind == dnnl::engine::kind::cpu) { + return dnnl::memory::format_tag::ab; + } else if (this->kind == dnnl::engine::kind::gpu) { + return dnnl::memory::format_tag::ab; + } else { + printf("[XFT][ERROR] Need a right engine kind in shift layout."); + std::exit(-1); + } + } + dnnl::memory::format_tag get_onednn_output_layout(dnnl::memory::data_type dt) { if (this->kind == dnnl::engine::kind::cpu) { - return dnnl::memory::format_tag::undef; + return dnnl::memory::format_tag::ab; } else if (this->kind == dnnl::engine::kind::gpu) { - return dnnl::memory::format_tag::AB32a16b; - // return dnnl::memory::format_tag::any; + return dnnl::memory::format_tag::ab; + // return dnnl::memory::format_tag::AB32a16b; } else { printf("[XFT][ERROR] Need a right engine kind in output layout."); std::exit(-1); } } + // Tin | Twei | Tout | Tbias | matmul + // --- | ---- | ---- | ----- | ------ + // f32 | f32 | f32 | f32 | sgemm + // f32 | f32 | f16 | f32 | sgemm_f32f32f16 + // f32 | f32 | bf16 | f32 | sgemm_f32f32bf16 + // f16 | f32 | f32 | f32 | sgemm_f16f32f32 + // bf16| f32 | f32 | f32 | sgemm_bf16f32f32 + // f16 | f32 | f16 | f32 | sgemm_f16f32f16 + // bf16| f32 | bf16 | f32 | sgemm_bf16f32bf16 + // f32 | f16 | f32 | f32 | hgemm_f32f16f32 + // f32 | f16 | f16 | f32 | hgemm_f32f16f16 + // f16 | f16 | f32 | f32 | hgemm_f16f16f32 + // f16 | f16 | f16 | f32 | hgemm + // f32 | bf16 | f32 | f32 | bgemm_f32bf16f32 + // f32 | bf16 | bf16 | f32 | bgemm_f32bf16bf16 + // bf16| bf16 | f32 | f32 | bgemm_bf16bf16f32 + // bf16| bf16 | bf16 | f32 | bgemm + template + void onednn_gemm_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, const Twei *packedB, + float beta, Tout *C, int ldc, const Tbias *bias = nullptr, const Tin *res = nullptr, int ldres = -1, + const matmul_kinds postAlg = matmul_kinds::Basic) { + TimeLine t("onednn_gemm_compute"); + TimeLine t1("onednn_gemm_compute.create_primitive"); + using namespace dnnl; + using tag = memory::format_tag; + using dt = memory::data_type; + + dt input_dt; + dt weight_dt; + dt shift_dt; + if constexpr (std::is_same_v) { + input_dt = dt::f32; + weight_dt = dt::f32; + shift_dt = dt::f32; + } else if constexpr (std::is_same_v) { + input_dt = dt::bf16; + weight_dt = dt::bf16; + shift_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + input_dt = dt::f16; + weight_dt = dt::f16; + shift_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: input and weight data type not supported."); + exit(-1); + } + + dt output_dt; + if constexpr (std::is_same_v) { + output_dt = dt::f32; + } else if constexpr (std::is_same_v) { + output_dt = dt::bf16; + } else if constexpr (std::is_same_v) 
{ + output_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: output data type not supported."); + exit(-1); + } + + dt bias_dt; + if constexpr (std::is_same_v) { + bias_dt = dt::f32; + } else if constexpr (std::is_same_v) { + bias_dt = dt::bf16; + } else if constexpr (std::is_same_v) { + bias_dt = dt::f16; + } else { + printf(">>> onednn_gemm_compute: bias data type not supported."); + exit(-1); + } + + matmul::primitive_desc *matmul_pd; + matmul *matmul_prim; + std::string key = create_key(transA, M, N, K, postAlg, packedB); + auto it = matmul_hub.find(key); + if (it != matmul_hub.end()) { + matmul_pd = std::get<0>(it->second); + matmul_prim = std::get<1>(it->second); + } else { + // Source (A), weights (B) and destination (C) matrix dimensions. + memory::dims input_dims = {M, K}; + memory::dims weight_dims = {K, N}; + memory::dims output_dims = {M, N}; + memory::dims bias_dims = {1, N}; + memory::dims shift_dims = {M, N}; + + // Create memory descriptors and memory objects for src, weights, bias, and dst. + auto input_md = memory::desc(input_dims, input_dt, get_onednn_input_layout(input_dt)); + auto weight_md = memory::desc(weight_dims, weight_dt, get_onednn_weight_layout(weight_dt)); + auto output_md = memory::desc(output_dims, output_dt, get_onednn_output_layout(output_dt)); + auto bias_md = memory::desc(bias_dims, bias_dt, get_onednn_bias_layout(bias_dt)); + auto shift_md = memory::desc(shift_dims, shift_dt, get_onednn_shift_layout(shift_dt)); + + // Create primitive descriptor and primitive. + primitive_attr matmul_attr; + switch (postAlg) { + case matmul_kinds::Basic: { + break; + } + case matmul_kinds::Silu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta); + matmul_attr.set_post_ops(matmul_ops); + break; + } + case matmul_kinds::Gelu: { + const float post_alpha = 1.0f; + const float post_beta = 0.0f; + post_ops matmul_ops; + matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta); + matmul_attr.set_post_ops(matmul_ops); + break; + } + case matmul_kinds::Residential: { + if (res == nullptr) { + printf(">>> onednn_gemm_compute: Residential requires a valid res pointer."); + exit(-1); + } + + post_ops matmul_ops; + matmul_ops.append_binary(algorithm::binary_add, shift_md); + matmul_attr.set_post_ops(matmul_ops); + break; + } + default: { + printf(">>> onednn_gemm_compute: postAlg type %s not supported.", std::to_string(postAlg).c_str()); + exit(-1); + } + } + + if (postAlg == matmul_kinds::Basic) { + if (bias != nullptr) + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md); + else + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md); + } else { + if (bias != nullptr) + matmul_pd + = new matmul::primitive_desc(*engine, input_md, weight_md, bias_md, output_md, matmul_attr); + else + matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr); + } + + matmul_prim = new matmul(*matmul_pd); + + // Cache primitive_desc and matmul + std::string key = create_key(transA, M, N, K, postAlg, packedB); + std::tuple value(matmul_pd, matmul_prim); + matmul_hub[key] = value; + } + + // Repack and convert input data. 
+ memory input_mem; + if constexpr (std::is_same_v) { + input_mem = memory(matmul_pd->src_desc(), *engine); + } else { + input_mem = memory(matmul_pd->src_desc(), *engine, const_cast(A)); + } + + memory weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast(packedB)); + memory output_mem = memory(matmul_pd->dst_desc(), *engine, C); + memory bias_mem; + if (bias != nullptr) { bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast(bias)); } + + memory shift_mem; + if (res != nullptr) { + memory::desc shift_md = memory::desc({M, N}, shift_dt, get_onednn_shift_layout(shift_dt)); + if constexpr (std::is_same_v) { + shift_mem = memory(shift_md, *engine); + } else { + shift_mem = memory(shift_md, *engine, const_cast(res)); + } + } + + // Create the primitive args. + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, input_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, weight_mem}); + if (bias != nullptr) { matmul_args.insert({DNNL_ARG_BIAS, bias_mem}); } + if (res != nullptr) { matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, shift_mem}); } + matmul_args.insert({DNNL_ARG_DST, output_mem}); + t1.release(); + + // Execution. + TimeLine t2("onednn_gemm_compute.execute_primitive"); + // Reorder + if constexpr (std::is_same_v && !std::is_same_v) { +#pragma omp parallel for + for (uint64_t i = 0; i < M; ++i) { + void *input_ptr = input_mem.get_data_handle(); + if constexpr (std::is_same_v) { + bfloat16_t::cvt_float_to_bfloat16(A + i * lda, (bfloat16_t *)input_ptr + i * K, K); + if (res != nullptr) { + void *shift_ptr = shift_mem.get_data_handle(); + bfloat16_t::cvt_float_to_bfloat16(res + i * ldres, (bfloat16_t *)shift_ptr + i * N, N); + } + } else if constexpr (std::is_same_v) { + float16_t::cvt_float_to_float16(A + i * lda, (float16_t *)input_ptr + i * K, K); + if (res != nullptr) { + void *shift_ptr = shift_mem.get_data_handle(); + float16_t::cvt_float_to_float16(res + i * ldres, (float16_t *)shift_ptr + i * N, N); + } + } else { + printf(">>> onednn_gemm_compute: input and res data type conversion not supported."); + exit(-1); + } + } + } + + matmul_prim->execute(*stream, matmul_args); + stream->wait(); + } + template void onednn_amx_sgemm_f32bf16f32_compute(bool transA, int M, int N, int K, float alpha, const Tin *A, int lda, const bfloat16_t *packedB, float beta, Tout *C, int ldc, const matmul_kinds postAlg = matmul_kinds::Basic) { @@ -1489,7 +1840,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1501,13 +1852,13 @@ class MMHelper { memory::dims output_dims = {M, N}; // Create memory descriptors and memory objects for src, weights, bias, and dst. 
- auto input_md = memory::desc(input_dims, dt::bf16, tag::ab); + auto input_md = memory::desc(input_dims, dt::bf16, get_onednn_input_layout(dt::bf16)); auto weight_md = memory::desc(weight_dims, dt::bf16, get_onednn_weight_layout(dt::bf16)); memory::desc output_md; if constexpr (std::is_same_v) { - output_md = memory::desc(output_dims, dt::f32, tag::ab); + output_md = memory::desc(output_dims, dt::f32, get_onednn_output_layout(dt::f32)); } else if constexpr (std::is_same_v) { - output_md = memory::desc(output_dims, dt::bf16, tag::ab); + output_md = memory::desc(output_dims, dt::bf16, get_onednn_output_layout(dt::bf16)); } else { printf(">>> onednn amx output data type not supported."); exit(-1); @@ -1545,7 +1896,7 @@ class MMHelper { } matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, postAlg); + std::string key = create_key(transA, M, N, K, postAlg, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1595,7 +1946,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1625,7 +1976,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1677,7 +2028,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1715,7 +2066,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu); + std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1767,7 +2118,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul); + std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1806,7 +2157,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul); + std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1874,7 +2225,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Residential); + std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -1920,7 +2271,7 @@ class MMHelper { } // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, 
matmul_kinds::Residential); + std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } @@ -1979,7 +2330,7 @@ class MMHelper { matmul::primitive_desc *matmul_pd; matmul *matmul_prim; - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic); + std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); auto it = matmul_hub.find(key); if (it != matmul_hub.end()) { matmul_pd = std::get<0>(it->second); @@ -2001,7 +2352,7 @@ class MMHelper { matmul_prim = new matmul(*matmul_pd); // Cache primitive_desc and matmul - std::string key = create_key(transA, M, N, K, matmul_kinds::Basic); + std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B); std::tuple value(matmul_pd, matmul_prim); matmul_hub[key] = value; } diff --git a/src/utils/simple_mem_pool.h b/src/utils/simple_mem_pool.h index 6fe36633..9ea074f9 100644 --- a/src/utils/simple_mem_pool.h +++ b/src/utils/simple_mem_pool.h @@ -24,7 +24,7 @@ class SimpleMemPool { private: - std::unordered_map> memoryMap; + std::unordered_map> memoryMap; // Private constructor to enforce Singleton pattern SimpleMemPool() {} @@ -46,7 +46,9 @@ class SimpleMemPool { } // Allocate or reallocate memory buffer based on name and size - void *getBuffer(const std::string &name, size_t size, size_t alignment = 64) { + void *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) { + if (name.empty()) return nullptr; + if (size == 0) { // std::cout << "[Warning] Try to allocate 0 bytes for buffer:" << name << std::endl; return nullptr; @@ -55,17 +57,17 @@ class SimpleMemPool { if (it != memoryMap.end()) { // Buffer with the given name found - if (it->second.second >= size) { + if (std::get<1>(it->second) >= size) { // Existing buffer size is sufficient, return it - return it->second.first; + return std::get<0>(it->second); } else { // Reallocate the buffer - free(it->second.first); + xft::dealloc(std::get<0>(it->second), std::get<2>(it->second)); } } // Allocate new aligned buffer - void *buffer = xft::alloc(size, alignment); + void *buffer = xft::alloc(size, device, alignment); if (buffer == nullptr) { // Allocation failed std::cerr << "Memory allocation failed for buffer:" << name << " size:" << size << std::endl; @@ -73,16 +75,28 @@ class SimpleMemPool { } // Update or insert entry in the mapping - memoryMap[name] = std::make_pair(buffer, size); + memoryMap[name] = std::make_tuple(buffer, size, device); return buffer; } + // Free allocated memory based on name + void freeBuffer(const std::string &name) { + auto it = memoryMap.find(name); + + if (it != memoryMap.end()) { + xft::dealloc(std::get<0>(it->second), std::get<2>(it->second)); + memoryMap.erase(it->first); + } + } + // Destructor to free all allocated memory on program termination ~SimpleMemPool() { +#ifndef XFT_GPU for (auto &entry : memoryMap) { - free(entry.second.first); + xft::dealloc(std::get<0>(entry.second), std::get<2>(entry.second)); } +#endif memoryMap.clear(); } }; \ No newline at end of file diff --git a/src/utils/transpose_util.h b/src/utils/transpose_util.h index f2a07d43..0ae7ec00 100644 --- a/src/utils/transpose_util.h +++ b/src/utils/transpose_util.h @@ -14,7 +14,7 @@ // ============================================================================ #ifndef _TRANSPOSE_H #define _TRANSPOSE_H -#include + #include #include diff --git a/src/utils/type_selector.h b/src/utils/type_selector.h index 7708ccad..252e83ed 100644 --- 
a/src/utils/type_selector.h +++ b/src/utils/type_selector.h @@ -31,9 +31,11 @@ struct TypeSelector { using OutType = bfloat16_t; }; -// template <> -// struct TypeSelector { -// using InType = float16_t; -// using ImType = float16_t; -// using OutType = float16_t; -// }; \ No newline at end of file +#ifdef XFT_GPU +template <> +struct TypeSelector { + using InType = float16_t; + using ImType = float16_t; + using OutType = float16_t; +}; +#endif \ No newline at end of file diff --git a/src/utils/verbose.h b/src/utils/verbose.h index 9c207e37..542e83c5 100644 --- a/src/utils/verbose.h +++ b/src/utils/verbose.h @@ -41,10 +41,82 @@ class Printer { printf("xft_verbose,exec,cpu,api,%s,m%dn%dk%d,%.6lf\n", api_func, M, N, K, ms); fflush(stdout); } + static void matrix(int rows, int cols, int stride, size_t totalmem) { printf("xft_verbose,matrix:rows%d_cols%d_stride%d,use:%zu bytes of memory\n", rows, cols, stride, totalmem); fflush(stdout); } + + template + static void print(std::string buf_name, T *buf, int rows, int cols, int stride, bool printAll = false, + void *device = nullptr) { + std::cout << buf_name.c_str() << ":" << std::endl; +#ifdef XFT_GPU + if (device != nullptr) { + sycl::queue *gpu_queue = static_cast(device); + gpu_queue + ->submit([&](sycl::handler &cgh) { + auto out = sycl::stream(10240, 7680, cgh); + cgh.parallel_for(sycl::nd_range<1>(1, 1), [=](sycl::nd_item<1> item) { + int idx_col = item.get_global_id(0); + if (idx_col == 0) { + if (printAll == false) { + for (int row = 0; row < 6; ++row) { + for (int col = 0; col < 6; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << " ... "; + for (int col = cols - 6; col < cols; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << sycl::endl; + } + out << "..." << sycl::endl; + for (int row = rows - 6; row < rows; ++row) { + for (int col = 0; col < 6; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << " ... "; + for (int col = cols - 6; col < cols; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << sycl::endl; + } + out << sycl::endl; + } else { + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < 6; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << " ... "; + for (int col = cols - 6; col < cols; ++col) { + out << (float)buf[row * stride + col] << ", "; + } + out << sycl::endl; + } + } + } + }); + }) + .wait(); + } else { + for (int row = 0; row < 6; ++row) { + for (int col = 0; col < 6; ++col) { + std::cout << (float)buf[row * stride + col] << ", "; + } + std::cout << std::endl; + } + std::cout << "..." 
<< std::endl; + for (int row = rows - 6; row < rows; ++row) { + for (int col = cols - 6; col < cols; ++col) { + std::cout << (float)buf[row * stride + col] << ", "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } +#endif + } }; #define GEMMVERBOSE(api_func, compute_func) \ diff --git a/tests/ut/attention_kernels_test.cpp b/tests/ut/attention_kernels_test.cpp index 57ba121a..09ed540d 100644 --- a/tests/ut/attention_kernels_test.cpp +++ b/tests/ut/attention_kernels_test.cpp @@ -86,7 +86,8 @@ static void selfAttentionRef(bfloat16_t *output, bfloat16_t *query, bfloat16_t * int kvHeadNum, int headSize, int oStride, int qStride, int kvStride, int batchSize, const int *tokenSizes, const float scale) { - int rowOffsets[batchSize] = {0}; + int rowOffsets[batchSize]; + memset(rowOffsets, 0 , batchSize * sizeof(int)); for (int i = 1; i < batchSize; i++) { rowOffsets[i] = rowOffsets[i - 1] + tokenSizes[i - 1]; } @@ -178,49 +179,49 @@ void testSelfAttention( } TEST(AttentionKernelsTest, SeparateCopyTest1) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {80}; testSelfAttention(128, 2, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest2) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {100}; testSelfAttention(128, 6, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest3) { - int batchSize = 2; + const int batchSize = 2; int tokenSizes[batchSize] = {100, 200}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest4) { - int batchSize = 3; + const int batchSize = 3; int tokenSizes[batchSize] = {100, 101, 102}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, SeparateCopyTest5) { - int batchSize = 4; + const int batchSize = 4; int tokenSizes[batchSize] = {100, 55, 111, 203}; testSelfAttention(128, 8, 2, tokenSizes, batchSize); } TEST(AttentionKernelsTest, FusedCopyTest1) { - int batchSize = 1; + const int batchSize = 1; int tokenSizes[batchSize] = {100}; testSelfAttention(128, 2, 2, tokenSizes, batchSize, false); } TEST(AttentionKernelsTest, FusedCopyTest2) { - int batchSize = 2; + const int batchSize = 2; int tokenSizes[batchSize] = {100, 101}; testSelfAttention(128, 4, 4, tokenSizes, batchSize, false); } TEST(AttentionKernelsTest, FusedCopyTest3) { - int batchSize = 4; + const int batchSize = 4; int tokenSizes[batchSize] = {100, 101, 102, 103}; testSelfAttention(128, 4, 4, tokenSizes, batchSize, false); } diff --git a/tests/ut/cross_attention_test.cpp b/tests/ut/cross_attention_test.cpp index 980b1307..c6694a77 100644 --- a/tests/ut/cross_attention_test.cpp +++ b/tests/ut/cross_attention_test.cpp @@ -97,7 +97,7 @@ static void crossAttentionRef(bfloat16_t *output, const bfloat16_t *query, const // Score = Softmax(Q * Kᵀ) softmaxRef(pscore, presentSeqLen); -#ifdef DEBUG +#ifdef XFT_DEBUG printf("pscore: "); for (int i = 0; i < presentSeqLen; ++i) { printf("%.6f ", pscore[i]); diff --git a/tests/ut/kv_reorder_test.cpp b/tests/ut/kv_reorder_test.cpp index f082bc7b..17890026 100644 --- a/tests/ut/kv_reorder_test.cpp +++ b/tests/ut/kv_reorder_test.cpp @@ -14,7 +14,7 @@ // ============================================================================ #include -#include "opt_decoder.h" +#include "kvcache_tensor.h" #include "gtest/gtest.h" template diff --git a/tests/ut/layers_attention_test.cpp b/tests/ut/layers_attention_test.cpp index 942393c7..134020f3 100644 --- a/tests/ut/layers_attention_test.cpp +++ 
b/tests/ut/layers_attention_test.cpp @@ -87,14 +87,17 @@ void test_AttentionLLaMA(void) { int nextTokenNum = 1; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); pastSeqLen += inputSeqLen; currentSeqLen = nextTokenNum; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); pastSeqLen += nextTokenNum; compareAttentionLLaMA(step++, batchSize, inputSeqLen, pastSeqLen, currentSeqLen, attHeadDim, attHeadNum, - kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + kvSize, oProj); + kvHeadNum, maxPositions, maxPosEmbed, hiddenSize, qkvProj, qkvProj + qSize, qkvProj + qSize + kvSize, + oProj); free(qkvProj); free(oProj); diff --git a/tests/ut/rotary_embedding_test.cpp b/tests/ut/rotary_embedding_test.cpp index c339e6ec..0008b320 100644 --- a/tests/ut/rotary_embedding_test.cpp +++ b/tests/ut/rotary_embedding_test.cpp @@ -26,7 +26,7 @@ static bool compare(const float *result, const float *ground_truth, const int si } TEST(RotrayEmbedding, RotrayEmbeddingTest) { - int bs = 2, seq = 2, headnum = 2, dim = 2; + const int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; int stride = bs * seq, size = bs * seq * headnum * dim; @@ -57,7 +57,7 @@ TEST(RotrayEmbedding, RotrayEmbeddingTest) { } TEST(RotrayEmbedding, BF16Test) { - int bs = 2, seq = 2, headnum = 2, dim = 2; + const int bs = 2, seq = 2, headnum = 2, dim = 2; int max_len = 10; int stride = bs * seq, size = bs * seq * headnum * dim;
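
Usage sketch (illustrative, not part of the patch): the device-aware buffer flow the changes above introduce. A sycl::queue is created the way createContext now selects its device, host- and device-side buffers go through the new allocator entry points, and data is staged with xft::memcopy. Assumes an XFT_GPU build with a SYCL 2020 toolchain; the include path and device index are placeholders.

#include <sycl/sycl.hpp>
#include "allocator.h" // xft::alloc/memcopy/dealloc as extended in this patch (path illustrative)

int main() {
    // Mirrors createContext: enumerate GPU devices and wrap the chosen one in a queue.
    auto gpus = sycl::device::get_devices(sycl::info::device_type::gpu);
    sycl::queue q(gpus.at(0)); // index 0 stands in for engineIdx

    const size_t nbytes = 1024 * sizeof(float);

    // device == nullptr: host path (posix_memalign).
    float *host = static_cast<float *>(xft::alloc(nbytes));
    for (int i = 0; i < 1024; ++i) host[i] = static_cast<float>(i);

    // Non-null device: routed to sycl::malloc_device on the queue's device.
    float *dev = static_cast<float *>(xft::alloc(nbytes, &q));

    xft::memcopy(dev, host, nbytes, &q); // gpu_queue->memcpy(...).wait()

    // ... launch kernels that read dev ...

    xft::dealloc(dev, &q); // sycl::free on the device allocation
    xft::dealloc(host);    // plain free on the host allocation
    return 0;
}

This is the same pattern the CommonDecoder XFT_GPU branch above follows for embBuf and outBuf: allocate on ctx->device, memcopy the staged activations in, and dealloc before falling back to actBuffers.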
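
A second sketch (also illustrative, not part of the patch): the oneDNN post-op wiring that onednn_gemm_compute relies on, reduced to a standalone CPU matmul with a fused SiLU. eltwise_swish with alpha = 1.0 is exactly SiLU, matching the matmul_kinds::Silu case above; shapes and values here are arbitrary.

#include "oneapi/dnnl/dnnl.hpp"
#include <vector>

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim M = 4, K = 8, N = 16;
    std::vector<float> A(M * K, 1.0f), B(K * N, 0.5f), C(M * N, 0.0f);

    auto a_md = memory::desc({M, K}, memory::data_type::f32, memory::format_tag::ab);
    auto b_md = memory::desc({K, N}, memory::data_type::f32, memory::format_tag::ab);
    auto c_md = memory::desc({M, N}, memory::data_type::f32, memory::format_tag::ab);

    // Same wiring as the matmul_kinds::Silu case: fuse swish(alpha=1) after the GEMM.
    post_ops ops;
    ops.append_eltwise(algorithm::eltwise_swish, 1.0f, 0.0f);
    primitive_attr attr;
    attr.set_post_ops(ops);

    matmul::primitive_desc pd(eng, a_md, b_md, c_md, attr);
    matmul prim(pd);

    memory a_mem(a_md, eng, A.data());
    memory b_mem(b_md, eng, B.data());
    memory c_mem(c_md, eng, C.data());
    prim.execute(strm, {{DNNL_ARG_SRC, a_mem}, {DNNL_ARG_WEIGHTS, b_mem}, {DNNL_ARG_DST, c_mem}});
    strm.wait(); // C now holds silu(A x B)
    return 0;
}

The Residential case works the same way, except the post-op is a binary_add whose second input is bound at execute time under DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, as the patch does for the shift memory.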