diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 83924a2f1..93b0f86ea 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -6,6 +6,9 @@
 namespace LLama.Extensions
 {
+    /// <summary>
+    /// Extension methods for the IModelParams interface
+    /// </summary>
     public static class IModelParamsExtensions
     {
         /// <summary>
@@ -31,7 +34,7 @@ public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
             result.n_gpu_layers = @params.GpuLayerCount;
             result.seed = @params.Seed;
             result.f16_kv = @params.UseFp16Memory;
-            result.use_mmap = @params.UseMemoryLock;
+            result.use_mmap = @params.UseMemorymap;
             result.use_mlock = @params.UseMemoryLock;
             result.logits_all = @params.Perplexity;
             result.embedding = @params.EmbeddingMode;
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 24b6ee80b..a74f11ee2 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -55,7 +55,7 @@ public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
                 text = text.Insert(0, " ");
             }
 
-            var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
+            var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
 
             // TODO(Rinne): deal with log of prompt
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c1475..901b347a0 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -30,8 +30,8 @@ public class InstructExecutor : StatefulExecutorBase
         public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
             string instructionSuffix = "\n\n### Response:\n\n") : base(model)
         {
-            _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
-            _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
+            _inp_pfx = _model.Tokenize(instructionPrefix, true);
+            _inp_sfx = _model.Tokenize(instructionSuffix, false);
             _instructionPrefix = instructionPrefix;
         }
 
@@ -133,7 +133,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)
                 _embed_inps.AddRange(_inp_sfx);
 
-                args.RemainedTokens -= line_inp.Count();
+                args.RemainedTokens -= line_inp.Length;
             }
         }
 
         /// <summary>
@@ -146,9 +146,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
             {
                 string last_output = "";
                 foreach (var id in _last_n_tokens)
-                {
-                    last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
-                }
+                    last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);
 
                 foreach (var antiprompt in args.Antiprompts)
                 {
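The executor diffs above all follow one pattern: tokenization now returns a concrete llama_token[] and detokenization moves from the static Utils helpers onto the native handle. A minimal sketch of what a call site looks like after the change (the model and prompt values are hypothetical, and a StringBuilder stands in for the executors' string concatenation, which is quadratic in the worst case):

    using System.Text;
    using LLama;

    static class TokenRoundtripExample
    {
        // Sketch only: assumes an already-constructed LLamaModel.
        static string Roundtrip(LLamaModel model, string prompt)
        {
            // Tokenize now returns llama_token[] directly, so callers can drop
            // ToArray() and use Length instead of the LINQ Count() extension.
            var tokens = model.Tokenize(prompt, true);

            // Detokenization goes through the handle rather than
            // Utils.PtrToString + NativeApi.llama_token_to_str.
            var sb = new StringBuilder();
            foreach (var id in tokens)
                sb.Append(model.NativeHandle.TokenToString(id, model.Encoding));
            return sb.ToString();
        }
    }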
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583ec..fef6a4d11 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -25,7 +25,7 @@ public class InteractiveExecutor : StatefulExecutorBase
         /// <param name="model"></param>
         public InteractiveExecutor(LLamaModel model) : base(model)
         {
-            _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
+            _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
         }
 
         /// <summary>
@@ -114,7 +114,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)
             }
             var line_inp = _model.Tokenize(text, false);
             _embed_inps.AddRange(line_inp);
-            args.RemainedTokens -= line_inp.Count();
+            args.RemainedTokens -= line_inp.Length;
         }
     }
 
@@ -133,7 +133,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable<string>? extraOutputs)
                 string last_output = "";
                 foreach (var id in _last_n_tokens)
                 {
-                    last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
+                    last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);
                 }
 
                 foreach (var antiprompt in args.Antiprompts)
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..526ffcb2b 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -64,10 +64,9 @@ public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger? logger = null)
         /// <param name="text"></param>
         /// <param name="addBos">Whether to add a BOS token to the text.</param>
         /// <returns></returns>
-        public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
+        public llama_token[] Tokenize(string text, bool addBos = true)
         {
-            // TODO: reconsider whether to convert to array here.
-            return Utils.Tokenize(_ctx, text, addBos, _encoding);
+            return _ctx.Tokenize(text, addBos, _encoding);
         }
 
         /// <summary>
@@ -79,9 +78,7 @@ public string DeTokenize(IEnumerable<llama_token> tokens)
         {
             StringBuilder sb = new();
             foreach(var token in tokens)
-            {
-                sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding));
-            }
+                sb.Append(_ctx.TokenToString(token, _encoding));
             return sb.ToString();
         }
 
@@ -285,8 +282,8 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
             int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
             bool penalizeNL = true)
         {
-            var n_vocab = NativeApi.llama_n_vocab(_ctx);
-            var logits = Utils.GetLogits(_ctx, n_vocab);
+            var n_vocab = _ctx.VocabCount;
+            var logits = _ctx.GetLogits();
 
             // Apply params.logit_bias map
             if(logitBias is not null)
@@ -338,7 +335,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
                     n_eval = Params.BatchSize;
                 }
 
-                if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
+                if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
                 {
                     _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
                     throw new RuntimeError("Failed to eval.");
@@ -353,9 +350,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
         internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
         {
             foreach(var id in ids)
-            {
-                yield return Utils.TokenToString(id, _ctx, _encoding);
-            }
+                yield return _ctx.TokenToString(id, _encoding);
         }
 
         /// <summary>
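LLamaModel.Eval keeps its batch-splitting loop but now consumes the bool-returning handle API instead of checking Utils.Eval for a non-zero result. A standalone sketch of that pattern, assuming an open SafeLLamaContextHandle named ctx (batch size and thread count are illustrative):

    using System;
    using LLama.Native;

    static class ChunkedEvalExample
    {
        static int EvalInChunks(SafeLLamaContextHandle ctx, int[] tokens, int nPast,
                                int batchSize = 512, int threads = 4)
        {
            for (var i = 0; i < tokens.Length; i += batchSize)
            {
                // Feed at most batchSize tokens per native call.
                var n = Math.Min(batchSize, tokens.Length - i);

                // Eval now reports failure as false rather than a non-zero int.
                if (!ctx.Eval(tokens.AsMemory(i, n), nPast, threads))
                    throw new InvalidOperationException("Failed to eval.");

                nPast += n;
            }
            return nPast;
        }
    }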
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..9075edd14 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -31,7 +31,7 @@ public StatelessExecutor(LLamaModel model)
             _model = model;
 
             var tokens = model.Tokenize(" ", true).ToArray();
-            Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
+            _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads);
             _originalState = model.GetState();
         }
 
@@ -52,7 +52,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
             List<llama_token> tokens = _model.Tokenize(text, true).ToList();
             int n_prompt_tokens = tokens.Count;
 
-            Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
+            _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads);
 
             lastTokens.AddRange(tokens);
             n_past += n_prompt_tokens;
@@ -76,7 +76,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
                 lastTokens.Add(id);
 
-                string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding);
+                string response = _model.NativeHandle.TokenToString(id, _model.Encoding);
                 yield return response;
 
                 tokens.Clear();
@@ -87,7 +87,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
                 string last_output = "";
                 foreach (var token in lastTokens)
                 {
-                    last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding);
+                    last_output += _model.NativeHandle.TokenToString(token, _model.Encoding);
                 }
 
                 bool should_break = false;
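The AsMemory calls introduced in this file are zero-copy views over the underlying array; only the pre-existing List<T>.ToArray() allocates. A small self-contained illustration of the slicing behaviour the new Eval relies on:

    using System;

    var tokens = new[] { 1, 2, 3, 4, 5 };

    // AsMemory(start, length) wraps the same array without copying...
    Memory<int> window = tokens.AsMemory(1, 3);

    // ...so writes through the span are visible in the original array.
    window.Span[0] = 99;
    Console.WriteLine(tokens[1]); // prints 99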
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 4e0ac2a29..edfb41528 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -207,6 +207,17 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
 
+        /// <summary>
+        /// Run the llama inference to obtain the logits and probabilities for the next token.
+        /// tokens + n_tokens is the provided batch of new tokens to process
+        /// n_past is the number of tokens to use from previous eval calls
+        /// </summary>
+        /// <param name="ctx"></param>
+        /// <param name="tokens"></param>
+        /// <param name="n_tokens"></param>
+        /// <param name="n_past"></param>
+        /// <param name="n_threads"></param>
+        /// <returns>Returns 0 on success</returns>
         [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
         public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
 
@@ -218,6 +229,7 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="text"></param>
+        /// <param name="encoding"></param>
         /// <param name="tokens"></param>
         /// <param name="n_max_tokens"></param>
         /// <param name="add_bos"></param>
@@ -256,8 +268,8 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos)
         /// <summary>
         /// Token logits obtained from the last call to llama_eval()
         /// The logits for the last token are stored in the last row
-        /// Can be mutated in order to change the probabilities of the next token
-        /// Rows: n_tokens
+        /// Can be mutated in order to change the probabilities of the next token.<br />
+        /// Rows: n_tokens<br />
         /// Cols: n_vocab
         /// </summary>
         /// <param name="ctx"></param>
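Binding two managed signatures to one native symbol via EntryPoint is standard P/Invoke: llama_eval keeps the convenient array form, while llama_eval_with_pointer lets callers who already hold pinned memory skip array marshalling. A hedged sketch of the mechanism against a made-up native function (mylib and my_sum are illustrative, not real):

    using System.Runtime.InteropServices;

    static class EntryPointExample
    {
        // Default binding: the managed name matches the C symbol `my_sum`
        // and the int[] argument is marshalled automatically.
        [DllImport("mylib", CallingConvention = CallingConvention.Cdecl)]
        public static extern int my_sum(int[] values, int count);

        // Same C symbol, different managed name and signature: EntryPoint
        // redirects this declaration to `my_sum`, taking a raw pointer.
        [DllImport("mylib", EntryPoint = "my_sum", CallingConvention = CallingConvention.Cdecl)]
        public static extern unsafe int my_sum_with_pointer(int* values, int count);
    }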
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index ab1022287..fa54f73ec 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,6 @@
 using System;
+using System.Buffers;
+using System.Text;
 using LLama.Exceptions;
 
 namespace LLama.Native
@@ -9,11 +11,29 @@ namespace LLama.Native
     public class SafeLLamaContextHandle
         : SafeLLamaHandleBase
     {
+        #region properties and fields
+        /// <summary>
+        /// Total number of tokens in vocabulary of this model
+        /// </summary>
+        public int VocabCount => ThrowIfDisposed().VocabCount;
+
+        /// <summary>
+        /// Total number of tokens in the context
+        /// </summary>
+        public int ContextSize => ThrowIfDisposed().ContextSize;
+
+        /// <summary>
+        /// Dimension of embedding vectors
+        /// </summary>
+        public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;
+
         /// <summary>
         /// This field guarantees that a reference to the model is held for as long as this handle is held
         /// </summary>
         private SafeLlamaModelHandle? _model;
+        #endregion
 
+        #region construction/destruction
         /// <summary>
         /// Create a new SafeLLamaContextHandle
         /// </summary>
@@ -42,6 +62,16 @@ protected override bool ReleaseHandle()
             return true;
         }
 
+        private SafeLlamaModelHandle ThrowIfDisposed()
+        {
+            if (IsClosed)
+                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
+            if (_model == null || _model.IsClosed)
+                throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");
+
+            return _model;
+        }
+
         /// <summary>
         /// Create a new llama_state for the given model
         /// </summary>
@@ -57,5 +87,103 @@ public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
 
             return new(ctx_ptr, model);
         }
+        #endregion
+
+        /// <summary>
+        /// Convert the given text into tokens
+        /// </summary>
+        /// <param name="text">The text to tokenize</param>
+        /// <param name="add_bos">Whether the "BOS" token should be added</param>
+        /// <param name="encoding">Encoding to use for the text</param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
+        public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+        {
+            ThrowIfDisposed();
+
+            // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
+            // possibly be more than this.
+            var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
+
+            // "Rent" an array to write results into (avoiding an allocation of a large array)
+            var temporaryArray = ArrayPool<int>.Shared.Rent(count);
+            try
+            {
+                // Do the actual conversion
+                var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+                if (n < 0)
+                {
+                    throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
+                                           "specify the encoding.");
+                }
+
+                // Copy the results from the rented array into an array which is exactly the right size
+                var result = new int[n];
+                Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
+
+                return result;
+            }
+            finally
+            {
+                ArrayPool<int>.Shared.Return(temporaryArray);
+            }
+        }
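+
+        // Note on the estimate above: llama_tokenize never produces more tokens than
+        // there are bytes of encoded input (every token consumes at least one byte),
+        // so byte count plus an optional BOS slot is always enough space, and the
+        // rented buffer is returned in the finally block even if tokenization throws.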
+
+        /// <summary>
+        /// Token logits obtained from the last call to llama_eval()
+        /// The logits for the last token are stored in the last row
+        /// Can be mutated in order to change the probabilities of the next token.<br />
+        /// Rows: n_tokens<br />
+        /// Cols: n_vocab
+        /// </summary>
+        /// <returns></returns>
+        /// <exception cref="ObjectDisposedException"></exception>
+        public Span<float> GetLogits()
+        {
+            var model = ThrowIfDisposed();
+
+            unsafe
+            {
+                var logits = NativeApi.llama_get_logits(this);
+                return new Span<float>(logits, model.VocabCount);
+            }
+        }
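+
+        // The span returned above aliases llama.cpp's own logits buffer, so writes
+        // through it change what the next sampling call sees - for example,
+        // `logits[id] = float.NegativeInfinity` effectively bans token `id`.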
+
+        /// <summary>
+        /// Convert a token into a string
+        /// </summary>
+        /// <param name="token"></param>
+        /// <param name="encoding"></param>
+        /// <returns></returns>
+        public string TokenToString(int token, Encoding encoding)
+        {
+            return ThrowIfDisposed().TokenToString(token, encoding);
+        }
+
+        /// <summary>
+        /// Convert a token into a span of bytes that could be decoded into a string
+        /// </summary>
+        /// <param name="token"></param>
+        /// <returns></returns>
+        public ReadOnlySpan<byte> TokenToSpan(int token)
+        {
+            return ThrowIfDisposed().TokenToSpan(token);
+        }
+
+        /// <summary>
+        /// Run the llama inference to obtain the logits and probabilities for the next token.
+        /// </summary>
+        /// <param name="tokens">The provided batch of new tokens to process</param>
+        /// <param name="n_past">The number of tokens to use from previous eval calls</param>
+        /// <param name="n_threads"></param>
+        /// <returns>Returns true on success</returns>
+        public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+        {
+            using var pin = tokens.Pin();
+            unsafe
+            {
+                return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+            }
+        }
     }
 }
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 79714fea2..dbb1b0707 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -13,17 +13,17 @@ public class SafeLlamaModelHandle
         /// <summary>
         /// Total number of tokens in vocabulary of this model
         /// </summary>
-        public int VocabCount { get; set; }
+        public int VocabCount { get; }
 
         /// <summary>
         /// Total number of tokens in the context
         /// </summary>
-        public int ContextSize { get; set; }
+        public int ContextSize { get; }
 
         /// <summary>
        /// Dimension of embedding vectors
         /// </summary>
-        public int EmbeddingCount { get; set; }
+        public int EmbeddingCount { get; }
 
         internal SafeLlamaModelHandle(IntPtr handle)
             : base(handle)
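Dropping the setters makes the model metadata immutable once the handle is constructed; a C# get-only auto-property can only be assigned from a constructor. A minimal illustration of the language rule (not the real class):

    // Sketch: a get-only auto-property is writable only in the constructor.
    class ModelMetadata
    {
        public int VocabCount { get; }  // no setter: fixed for the object's lifetime

        public ModelMetadata(int vocabCount)
        {
            VocabCount = vocabCount;    // allowed here; a compile error anywhere else
        }
    }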
Please try to " + - "specify the encoding."); - } - return res.Take(n); + return ctx.Tokenize(text, add_bos, encoding); } - public static unsafe Span GetLogits(SafeLLamaContextHandle ctx, int length) + [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] + public static Span GetLogits(SafeLLamaContextHandle ctx, int length) { - var logits = NativeApi.llama_get_logits(ctx); - return new Span(logits, length); + if (length != ctx.VocabCount) + throw new ArgumentException("length must be the VocabSize"); + + return ctx.GetLogits(); } - public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) + [Obsolete("Use SafeLLamaContextHandle Eval method instead")] + public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) { - int result; - fixed(llama_token* p = tokens) - { - result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads); - } - return result; + var slice = tokens.AsMemory().Slice(startIndex, n_tokens); + return ctx.Eval(slice, n_past, n_threads) ? 0 : 1; } + [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")] public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding) { - return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding); + return ctx.TokenToString(token, encoding); } - public static unsafe string PtrToString(IntPtr ptr, Encoding encoding) + [Obsolete("No longer used internally by LlamaSharp")] + public static string PtrToString(IntPtr ptr, Encoding encoding) { #if NET6_0_OR_GREATER if(encoding == Encoding.UTF8) @@ -77,21 +70,24 @@ public static unsafe string PtrToString(IntPtr ptr, Encoding encoding) return Marshal.PtrToStringAuto(ptr); } #else - byte* tp = (byte*)ptr.ToPointer(); - List bytes = new(); - while (true) + unsafe { - byte c = *tp++; - if (c == '\0') - { - break; - } - else + byte* tp = (byte*)ptr.ToPointer(); + List bytes = new(); + while (true) { - bytes.Add(c); + byte c = *tp++; + if (c == '\0') + { + break; + } + else + { + bytes.Add(c); + } } + return encoding.GetString(bytes.ToArray()); } - return encoding.GetString(bytes.ToArray()); #endif } }