[Kernel] Add GPU kernels and enable LLaMA model. (#372)
changqi1 authored Jun 14, 2024
1 parent 24242ff commit 80df391
Showing 42 changed files with 1,336 additions and 302 deletions.
12 changes: 6 additions & 6 deletions CMakeLists.txt
@@ -20,7 +20,7 @@ project(xfastertransformer LANGUAGES C CXX)
option(WITH_GPU "Build with GPU" OFF)
if(WITH_GPU)
message(STATUS "Notice: Building with GPU.")
-add_definitions(-DGPU=true)
+add_definitions(-DXFT_GPU=true)
# Get compiler version
execute_process(COMMAND ${CMAKE_CXX_COMPILER} --version
OUTPUT_VARIABLE ICPX_VERSION
@@ -35,10 +35,6 @@ else()
message(STATUS "Notice: GCC version: ${GCC_VERSION}")
endif()

-if(NOT CMAKE_BUILD_TYPE)
-set(CMAKE_BUILD_TYPE Release)
-endif()
-
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mavx512bw -mavx512vl -fPIC")
if(WITH_GPU)
@@ -73,11 +69,15 @@ if(GCC_VERSION VERSION_GREATER_EQUAL "10.1")
endif()
endif()

+if(NOT CMAKE_BUILD_TYPE)
+set(CMAKE_BUILD_TYPE Release)
+endif()
+
if(CMAKE_BUILD_TYPE MATCHES "Debug")
message(STATUS "Notice: Using Debug mode.")
set(CMAKE_C_FLAGS "-O0 -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g")
-add_definitions(-DDEBUG=true)
+add_definitions(-DXFT_DEBUG=true)
add_definitions(-DSTEP_BY_STEP_ATTN=true)
else()
message(STATUS "Notice: Using Release mode.")
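Note on the renamed defines: prefixing GPU and DEBUG with XFT_ keeps the project's build flags from colliding with identically named macros defined elsewhere. A minimal sketch of the gating pattern this enables in C++ sources (the XFT_TRACE helper is illustrative, not part of this commit):

```cpp
#include <cstdio>

// Hypothetical trace helper: prints only in builds configured with
// -DXFT_DEBUG=true, compiles to nothing otherwise.
#ifdef XFT_DEBUG
#define XFT_TRACE(...) std::printf(__VA_ARGS__)
#else
#define XFT_TRACE(...) ((void)0)
#endif

int main() {
    XFT_TRACE("attention step %d done\n", 1); // emitted only under XFT_DEBUG
    return 0;
}
```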
8 changes: 8 additions & 0 deletions requirements-gpu.txt
@@ -0,0 +1,8 @@
+-f https://download.pytorch.org/whl/torch_stable.html
+cmake==3.26.1
+sentencepiece==0.1.99
+torch==2.3.0+cpu.cxx11.abi
+transformers==4.40.0
+accelerate==0.23.0
+protobuf
+tiktoken
59 changes: 56 additions & 3 deletions src/common/allocator.h
@@ -15,8 +15,13 @@
#pragma once
#include <cstdio>
#include <cstdlib>
-#include <sys/mman.h>
#include <cstring>
#include "environment.h"
+#include <sys/mman.h>
+
+#ifdef XFT_GPU
+#include <CL/sycl.hpp>
+#endif

namespace xft {

@@ -26,10 +31,22 @@ static inline bool is_thp_alloc(size_t nbytes) {
return (Env::getInstance().getTHPEnabled() && (nbytes >= g_thp_threshold));
}

-static inline void *alloc(size_t nbytes, size_t alignment = 64) {
+static inline void *alloc(size_t nbytes, void *device = nullptr, size_t alignment = 64) {
if (nbytes == 0) { return nullptr; }

-void *data;
+void *data = nullptr;

+#ifdef XFT_GPU
+if (device != nullptr) {
+sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+data = sycl::malloc_device<char>(nbytes, *gpu_queue);
+if (data == nullptr) {
+printf("Unable to allocate buffer with size of %zu in GPU.\n", nbytes);
+exit(-1);
+}
+return data;
+}
+#endif

int err = posix_memalign(&data, alignment, nbytes);
if (err != 0) {
@@ -47,4 +64,40 @@ static inline void *alloc(size_t nbytes, size_t alignment = 64) {

return data;
}

+static inline void dealloc(void *data, void *device = nullptr) {
+#ifdef XFT_GPU
+if (device != nullptr) {
+sycl::free(data, *static_cast<sycl::queue *>(device));
+return;
+}
+#endif
+
+free(data);
+}
+
+static inline void memcopy(void *dst, const void *src, size_t size, void *device = nullptr) {
+#ifdef XFT_GPU
+if (device != nullptr) {
+sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+gpu_queue->memcpy(dst, src, size).wait();
+return;
+}
+#endif
+
+memcpy(dst, src, size);
+}
+
+static inline void memsetv(void *dst, int ch, size_t size, void *device = nullptr) {
+#ifdef XFT_GPU
+if (device != nullptr) {
+sycl::queue *gpu_queue = static_cast<sycl::queue *>(device);
+gpu_queue->memset(dst, ch, size).wait();
+return;
+}
+#endif
+
+memset(dst, ch, size);
+}

} // namespace xft
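Taken together, alloc/dealloc/memcopy/memsetv form one device-agnostic memory API: a sycl::queue* passed as device routes to SYCL unified shared memory, while nullptr keeps the existing aligned host path. A usage sketch, assuming an XFT_GPU build and a visible GPU device:

```cpp
#include "allocator.h" // brings in <CL/sycl.hpp> under XFT_GPU

int main() {
#ifdef XFT_GPU
    sycl::queue q(sycl::gpu_selector_v); // assumes a GPU is present
    const size_t bytes = 1024 * sizeof(float);
    float *host = static_cast<float *>(xft::alloc(bytes));     // aligned host allocation
    float *dev  = static_cast<float *>(xft::alloc(bytes, &q)); // sycl::malloc_device path
    xft::memsetv(dev, 0, bytes, &q);    // device-side memset
    xft::memcopy(dev, host, bytes, &q); // blocking host-to-device copy
    xft::dealloc(dev, &q);              // sycl::free
    xft::dealloc(host);                 // plain free
#endif
    return 0;
}
```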
77 changes: 74 additions & 3 deletions src/common/sequence.h
@@ -19,6 +19,7 @@
#include <queue>
#include <unordered_map>

#include "allocator.h"
#include "environment.h"
#include "sampling_params.h"

@@ -67,7 +68,7 @@ class SequenceIDManager {
// The SequenceMeta is one sequence of batch inputs and includes the generated tokens.
class SequenceMeta {
public:
-SequenceMeta(std::vector<int32_t> &_promptTokens)
+SequenceMeta(const std::vector<int32_t> &_promptTokens)
: sequenceID(SequenceIDManager::getInstance().createSequenceID())
, inputSeqLen(_promptTokens.size())
, pastSeqLen(0)
@@ -81,6 +82,16 @@
, promptTokens(_inputSeqLen, 0)
, step(0) {}

+SequenceMeta(int32_t _sequenceID, const std::vector<int32_t> &_promptTokens)
+: sequenceID(_sequenceID)
+, inputSeqLen(_promptTokens.size())
+, pastSeqLen(0)
+, promptTokens(_promptTokens)
+, step(0) {}
+
+SequenceMeta(int32_t _sequenceID, int32_t _inputSeqLen)
+: sequenceID(_sequenceID), inputSeqLen(_inputSeqLen), pastSeqLen(0), promptTokens(_inputSeqLen, 0), step(0) {}

~SequenceMeta() {}

int32_t getSequenceID() const { return sequenceID; }
@@ -175,7 +186,8 @@ class SequenceGroupMeta {
groupID = sequences[0].getSequenceID();
}

-SequenceGroupMeta(std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_) : samplingMeta(samplingMeta_) {
+SequenceGroupMeta(const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
+: samplingMeta(samplingMeta_) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -191,7 +203,7 @@
groupID = sequences[0].getSequenceID();
}

-SequenceGroupMeta(std::vector<int32_t> &_inputTokens) {
+SequenceGroupMeta(const std::vector<int32_t> &_inputTokens) {
sequences.reserve(samplingMeta.config.numBeams);
for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
sequences.emplace_back(SequenceMeta(_inputTokens));
@@ -207,6 +219,40 @@
groupID = sequences[0].getSequenceID();
}

+SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens, SamplingMeta &samplingMeta_)
+: samplingMeta(samplingMeta_) {
+sequences.reserve(samplingMeta.config.numBeams);
+for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
+}
+groupID = sequences[0].getSequenceID();
+}
+
+SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen, SamplingMeta &samplingMeta_)
+: samplingMeta(samplingMeta_) {
+sequences.reserve(samplingMeta.config.numBeams);
+for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
+}
+groupID = sequences[0].getSequenceID();
+}
+
+SequenceGroupMeta(int32_t _sequenceID, const std::vector<int32_t> &_inputTokens) {
+sequences.reserve(samplingMeta.config.numBeams);
+for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+sequences.emplace_back(SequenceMeta(_sequenceID, _inputTokens));
+}
+groupID = sequences[0].getSequenceID();
+}
+
+SequenceGroupMeta(int32_t _sequenceID, int32_t _inputSeqLen) {
+sequences.reserve(samplingMeta.config.numBeams);
+for (int i = 0; i < samplingMeta.config.numBeams; ++i) {
+sequences.emplace_back(SequenceMeta(_sequenceID, _inputSeqLen));
+}
+groupID = sequences[0].getSequenceID();
+}
+
int32_t getGroupID() { return groupID; }

int32_t getGroupSize() { return samplingMeta.config.numBeams; }
@@ -272,6 +318,31 @@ class SequencePool {
return group;
}

+SequenceGroupMeta *newGroupMeta(
+int32_t sequenceID, std::vector<int32_t> &inputTokens, SamplingMeta &samplingMeta_) {
+auto *group = new SequenceGroupMeta(sequenceID, inputTokens, samplingMeta_);
+this->add(group);
+return group;
+}
+
+SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen, SamplingMeta &samplingMeta_) {
+auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen, samplingMeta_);
+this->add(group);
+return group;
+}
+
+SequenceGroupMeta *newGroupMeta(int32_t sequenceID, std::vector<int32_t> &inputTokens) {
+auto *group = new SequenceGroupMeta(sequenceID, inputTokens);
+this->add(group);
+return group;
+}
+
+SequenceGroupMeta *newGroupMeta(int32_t sequenceID, int32_t inputSeqLen) {
+auto *group = new SequenceGroupMeta(sequenceID, inputSeqLen);
+this->add(group);
+return group;
+}
+
bool add(SequenceGroupMeta *sequenceGroup, bool force = false) {
int32_t groupID = sequenceGroup->getGroupID();
bool isSuccess = false;
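The ID-forwarding overloads let a caller that already assigned its own request IDs (for example, a serving-layer scheduler) create groups under an external sequence ID instead of one minted by SequenceIDManager. A sketch of the intended call pattern (function and variable names here are illustrative):

```cpp
#include <cassert>
#include <vector>
#include "sequence.h"

void enqueuePrompt(SequencePool &pool) {
    std::vector<int32_t> prompt = {1, 15043, 3186}; // illustrative token IDs
    // Reuse an externally assigned ID (42) for every sequence in the group.
    SequenceGroupMeta *group = pool.newGroupMeta(/*sequenceID=*/42, prompt);
    // The group ID mirrors the caller-supplied ID of the first sequence.
    assert(group->getGroupID() == 42);
}
```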
32 changes: 22 additions & 10 deletions src/common/transformer_ctx.h
@@ -112,7 +112,8 @@ struct DecoderContext {
xft::Matrix<float> qkvMatMul; // query, key, value
xft::Matrix<float> imOut; // intermediate output

-MMHelper *mmHelper;
+MMHelper *mmHelper = nullptr;
+void *device = nullptr;

std::string configPath;
INIReader configReader;
@@ -130,7 +131,7 @@
public:
DecoderContext(int _layers, int _hiddenSize, int _headSize, int _attHeadNum, int _kvHeadNum, int _imSize, const std::string &act,
float epsilon, int _vocabSize, int _embeddingSize, int _maxPositions, int _maxPosEmbed, int _maxSeqLength,
-int _splitIdx, int _splits, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
+int _splitIdx, int _splits, MMHelper *mmHelper, void *device = nullptr, int _ppSize = 1, int _ppRank = 0, RopeParams *_ropeParamsPtr = nullptr,
bool _useLogN = true, bool _useNTK = true, int numThreads = 0)
: layers(_layers)
, hiddenSize(_hiddenSize)
@@ -170,9 +171,12 @@ struct DecoderContext {
}
}

+this->mmHelper = mmHelper;
+this->device = device;
+
this->rawBufSize = 4 * 32 * intermediateSize + 4 * attHeadNum * 32 * 32; // assume bs=4, seq=32
-this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
-memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
+this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
+xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);

if (act == "relu") {
this->actType = RELU;
@@ -240,8 +244,12 @@ struct DecoderContext {
bool cached(const std::string &name) { return SimpleMemPool::instance().cached(name); }

template <typename T>
-T *getBuffer(const std::string &name, size_t size, size_t alignment = 64) {
-return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, alignment);
+T *getBuffer(const std::string &name, size_t size, void *device = nullptr, size_t alignment = 64) {
+return (T *)SimpleMemPool::instance().getBuffer(name, sizeof(T) * size, device, alignment);
}

+void freeBuffer(const std::string &name) {
+SimpleMemPool::instance().freeBuffer(name);
+}
+
void dump() {
@@ -286,10 +294,10 @@ struct DecoderContext {
uint64_t total = size1 + size2 + size3;
if (total > this->rawBufSize) {
this->rawBufSize = total;
-free(this->rawBuffer);
+if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);

-this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize);
-memset(this->rawBuffer, 0, sizeof(float) * rawBufSize);
+this->rawBuffer = (float *)xft::alloc(sizeof(float) * rawBufSize, this->device);
+xft::memsetv(this->rawBuffer, 0, sizeof(float) * rawBufSize, this->device);
}

// Assign the buffer
@@ -312,5 +320,9 @@
return rawBufSize - size1 - size2;
}

-~DecoderContext() { free(this->rawBuffer); }
+~DecoderContext() {
+#ifndef XFT_GPU
+if (this->rawBuffer) xft::dealloc(this->rawBuffer, this->device);
+#endif
+}
};
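With the device handle threaded through DecoderContext, pooled scratch buffers can now live in device memory, and freeBuffer gives them an explicit release point. A sketch, assuming ctx was built with a non-null device in an XFT_GPU build (the buffer name is illustrative):

```cpp
// Device-aware scratch buffer: lands in GPU memory when ctx->device is a
// sycl::queue*, in host memory when it is nullptr.
void scratchExample(DecoderContext *ctx, size_t batch, size_t seq) {
    float *scores = ctx->getBuffer<float>("tmp_attn_scores", batch * seq, ctx->device);
    xft::memsetv(scores, 0, sizeof(float) * batch * seq, ctx->device);
    // ... use scores ...
    ctx->freeBuffer("tmp_attn_scores"); // release the named pool entry
}
```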
6 changes: 3 additions & 3 deletions src/kernels/attention_kernels.cpp
@@ -66,7 +66,7 @@ void crossAttention(bfloat16_t *output, bfloat16_t *query, bfloat16_t *key, bflo
small_sgemm_bf16bf16f32_b(true, m, n, k, (XDNN_BF16 *)A, lda, (XDNN_BF16 *)baseB, ldb, C, ldc, blkIndices,
cacheBlkStride, cacheBlkSize);

-#ifdef DEBUG
+#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
printf("Q * K, first head:\n");
auto p = C;
@@ -78,7 +78,7 @@
// Softmax(Q * K)
small_softmax_f32(C, scale, n);

-#ifdef DEBUG
+#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
printf("Softmax(Q * K), first head:\n");
auto p = C;
@@ -100,7 +100,7 @@
small_sgemm_f32bf16bf16_b(false, m, n, k, C, lda, (XDNN_BF16 *)baseB, ldb, (XDNN_BF16 *)baseC, ldc,
blkIndices, cacheBlkStride, cacheBlkSize);

-#ifdef DEBUG
+#ifdef XFT_DEBUG
if (b == 0 && i == 0) {
printf("Softmax(Q * K) * V, first head:\n");
auto p = C;
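The three XFT_DEBUG checkpoints above print the intermediate stages of scaled dot-product attention. For reference, a naive single-head sketch of the math the optimized small_sgemm/small_softmax calls implement (reference semantics only, not this kernel's blocked KV-cache layout):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference: out = Softmax(scale * Q * K^T) * V for one head.
// Q: [m x k], K: [n x k], V: [n x k], out: [m x k], row-major.
void naiveAttention(const float *Q, const float *K, const float *V, float *out,
                    int m, int n, int k, float scale) {
    std::vector<float> S(n);
    for (int i = 0; i < m; ++i) {
        float maxv = -INFINITY;
        for (int j = 0; j < n; ++j) { // stage 1: Q * K
            float dot = 0.f;
            for (int t = 0; t < k; ++t) dot += Q[i * k + t] * K[j * k + t];
            S[j] = scale * dot;
            maxv = std::max(maxv, S[j]);
        }
        float sum = 0.f; // stage 2: Softmax(Q * K), max-subtracted for stability
        for (int j = 0; j < n; ++j) { S[j] = std::exp(S[j] - maxv); sum += S[j]; }
        for (int j = 0; j < n; ++j) S[j] /= sum;
        for (int t = 0; t < k; ++t) { // stage 3: Softmax(Q * K) * V
            float acc = 0.f;
            for (int j = 0; j < n; ++j) acc += S[j] * V[j * k + t];
            out[i * k + t] = acc;
        }
    }
}
```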
(The remaining 36 changed files are not shown.)
