diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 83924a2f1..93b0f86ea 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -6,6 +6,9 @@
namespace LLama.Extensions
{
+ /// <summary>
+ /// Extension methods to the IModelParams interface
+ /// </summary>
public static class IModelParamsExtensions
{
///
@@ -31,7 +34,7 @@ public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out L
result.n_gpu_layers = @params.GpuLayerCount;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
- result.use_mmap = @params.UseMemoryLock;
+ result.use_mmap = @params.UseMemorymap;
result.use_mlock = @params.UseMemoryLock;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 24b6ee80b..a74f11ee2 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -55,7 +55,7 @@ public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos =
text = text.Insert(0, " ");
}
- var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
+ var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
// TODO(Rinne): deal with log of prompt
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index e055c1475..901b347a0 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -30,8 +30,8 @@ public class InstructExecutor : StatefulExecutorBase
public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
string instructionSuffix = "\n\n### Response:\n\n") : base(model)
{
- _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
- _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
+ _inp_pfx = _model.Tokenize(instructionPrefix, true);
+ _inp_sfx = _model.Tokenize(instructionSuffix, false);
_instructionPrefix = instructionPrefix;
}
@@ -133,7 +133,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)
_embed_inps.AddRange(_inp_sfx);
- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}
///
@@ -146,9 +146,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
{
string last_output = "";
foreach (var id in _last_n_tokens)
- {
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
- }
+ last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);
foreach (var antiprompt in args.Antiprompts)
{
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index f5c1583ec..fef6a4d11 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -25,7 +25,7 @@ public class InteractiveExecutor : StatefulExecutorBase
///
public InteractiveExecutor(LLamaModel model) : base(model)
{
- _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
+ _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
}
///
@@ -114,7 +114,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)
}
var line_inp = _model.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}
@@ -133,7 +133,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
string last_output = "";
foreach (var id in _last_n_tokens)
{
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
+ last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);
}
foreach (var antiprompt in args.Antiprompts)
diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs
index 2bd31199f..526ffcb2b 100644
--- a/LLama/LLamaModel.cs
+++ b/LLama/LLamaModel.cs
@@ -64,10 +64,9 @@ public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger?
///
/// Whether to add a bos to the text.
///
- public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
+ public llama_token[] Tokenize(string text, bool addBos = true)
{
- // TODO: reconsider whether to convert to array here.
- return Utils.Tokenize(_ctx, text, addBos, _encoding);
+ return _ctx.Tokenize(text, addBos, _encoding);
}
///
@@ -79,9 +78,7 @@ public string DeTokenize(IEnumerable<llama_token> tokens)
{
StringBuilder sb = new();
foreach(var token in tokens)
- {
- sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding));
- }
+ sb.Append(_ctx.TokenToString(token, _encoding));
return sb.ToString();
}
@@ -285,8 +282,8 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
- var n_vocab = NativeApi.llama_n_vocab(_ctx);
- var logits = Utils.GetLogits(_ctx, n_vocab);
+ var n_vocab = _ctx.VocabCount;
+ var logits = _ctx.GetLogits();
// Apply params.logit_bias map
if(logitBias is not null)
@@ -338,7 +335,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
n_eval = Params.BatchSize;
}
- if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
+ if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
{
_logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
@@ -353,9 +350,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
{
foreach(var id in ids)
- {
- yield return Utils.TokenToString(id, _ctx, _encoding);
- }
+ yield return _ctx.TokenToString(id, _encoding);
}
///
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 06b5159d2..9075edd14 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -31,7 +31,7 @@ public StatelessExecutor(LLamaModel model)
_model = model;
var tokens = model.Tokenize(" ", true).ToArray();
- Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
+ _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads);
_originalState = model.GetState();
}
@@ -52,7 +52,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
List<llama_token> tokens = _model.Tokenize(text, true).ToList();
int n_prompt_tokens = tokens.Count;
- Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
+ _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads);
lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
@@ -76,7 +76,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
lastTokens.Add(id);
- string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding);
+ string response = _model.NativeHandle.TokenToString(id, _model.Encoding);
yield return response;
tokens.Clear();
@@ -87,7 +87,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
string last_output = "";
foreach (var token in lastTokens)
{
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding);
+ last_output += _model.NativeHandle.TokenToString(token, _model.Encoding);
}
bool should_break = false;
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 4e0ac2a29..edfb41528 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -207,6 +207,17 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);
+ /// <summary>
+ /// Run the llama inference to obtain the logits and probabilities for the next token.
+ /// tokens + n_tokens is the provided batch of new tokens to process
+ /// n_past is the number of tokens to use from previous eval calls
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="tokens"></param>
+ /// <param name="n_tokens"></param>
+ /// <param name="n_past"></param>
+ /// <param name="n_threads"></param>
+ /// <returns>Returns 0 on success</returns>
[DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);
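+
+ // Illustrative usage sketch (not part of this change): the pointer overload is meant to be
+ // called with pinned memory, e.g. assuming `ctx` is a valid context handle and `tokens` is a
+ // llama_token[] (llama_token is an alias for int):
+ //
+ //   fixed (llama_token* ptr = tokens)
+ //   {
+ //       if (NativeApi.llama_eval_with_pointer(ctx, ptr, tokens.Length, n_past, n_threads) != 0)
+ //           throw new RuntimeError("Failed to eval.");
+ //   }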
@@ -218,6 +229,7 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
///
///
///
+ ///
///
///
///
@@ -256,8 +268,8 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
- /// Can be mutated in order to change the probabilities of the next token
- /// Rows: n_tokens
+ /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
/// Cols: n_vocab
///
///
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index ab1022287..fa54f73ec 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,6 @@
using System;
+using System.Buffers;
+using System.Text;
using LLama.Exceptions;
namespace LLama.Native
@@ -9,11 +11,29 @@ namespace LLama.Native
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
+ #region properties and fields
+ /// <summary>
+ /// Total number of tokens in vocabulary of this model
+ /// </summary>
+ public int VocabCount => ThrowIfDisposed().VocabCount;
+
+ /// <summary>
+ /// Total number of tokens in the context
+ /// </summary>
+ public int ContextSize => ThrowIfDisposed().ContextSize;
+
+ /// <summary>
+ /// Dimension of embedding vectors
+ /// </summary>
+ public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;
+
///
/// This field guarantees that a reference to the model is held for as long as this handle is held
///
private SafeLlamaModelHandle? _model;
+ #endregion
+ #region construction/destruction
///
/// Create a new SafeLLamaContextHandle
///
@@ -42,6 +62,16 @@ protected override bool ReleaseHandle()
return true;
}
+ private SafeLlamaModelHandle ThrowIfDisposed()
+ {
+ if (IsClosed)
+ throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
+ if (_model == null || _model.IsClosed)
+ throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");
+
+ return _model;
+ }
+
///
/// Create a new llama_state for the given model
///
@@ -57,5 +87,103 @@ public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaCon
return new(ctx_ptr, model);
}
+ #endregion
+
+ /// <summary>
+ /// Convert the given text into tokens
+ /// </summary>
+ /// <param name="text">The text to tokenize</param>
+ /// <param name="add_bos">Whether the "BOS" token should be added</param>
+ /// <param name="encoding">Encoding to use for the text</param>
+ /// <returns></returns>
+ /// <exception cref="RuntimeError"></exception>
+ public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+ {
+ ThrowIfDisposed();
+
+ // Calculate the number of bytes in the string; this is a pessimistic upper bound on the
+ // token count, since the tokenizer can never produce more tokens than there are bytes.
+ var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
+
+ // "Rent" an array to write results into (avoiding an allocation of a large array)
+ var temporaryArray = ArrayPool<int>.Shared.Rent(count);
+ try
+ {
+ // Do the actual conversion
+ var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+ if (n < 0)
+ {
+ throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
+ "specify the encoding.");
+ }
+
+ // Copy the results from the rented array into a new array which is exactly the right size
+ var result = new int[n];
+ Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
+
+ return result;
+ }
+ finally
+ {
+ ArrayPool<int>.Shared.Return(temporaryArray);
+ }
+ }
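+
+ // Usage sketch (illustrative only, assuming `ctx` is a live SafeLLamaContextHandle):
+ //   int[] ids = ctx.Tokenize("Hello world", true, Encoding.UTF8);
+ //   // `ids` holds exactly the produced token ids; the BOS token comes first when add_bos is true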
+
+ /// <summary>
+ /// Token logits obtained from the last call to llama_eval()
+ /// The logits for the last token are stored in the last row
+ /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
+ /// Cols: n_vocab
+ /// </summary>
+ /// <returns></returns>
+ public Span<float> GetLogits()
+ {
+ var model = ThrowIfDisposed();
+
+ unsafe
+ {
+ var logits = NativeApi.llama_get_logits(this);
+ return new Span<float>(logits, model.VocabCount);
+ }
+ }
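+
+ // Usage sketch (illustrative only): the returned span aliases native memory owned by the
+ // context, so mutating it changes sampling of the next token, e.g.
+ //   var logits = ctx.GetLogits();
+ //   logits[bannedTokenId] = float.NegativeInfinity; // `bannedTokenId` is a hypothetical token id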
+
+ /// <summary>
+ /// Convert a token into a string
+ /// </summary>
+ /// <param name="token"></param>
+ /// <param name="encoding"></param>
+ /// <returns></returns>
+ public string TokenToString(int token, Encoding encoding)
+ {
+ return ThrowIfDisposed().TokenToString(token, encoding);
+ }
+
+ /// <summary>
+ /// Convert a token into a span of bytes that could be decoded into a string
+ /// </summary>
+ /// <param name="token"></param>
+ /// <returns></returns>
+ public ReadOnlySpan<byte> TokenToSpan(int token)
+ {
+ return ThrowIfDisposed().TokenToSpan(token);
+ }
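+
+ // Note (illustrative): TokenToSpan exposes the raw bytes for a token, which helps when a single
+ // token is not valid text on its own; TokenToString is roughly equivalent to
+ //   encoding.GetString(ctx.TokenToSpan(token).ToArray())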
+
+ /// <summary>
+ /// Run the llama inference to obtain the logits and probabilities for the next token.
+ /// </summary>
+ /// <param name="tokens">The provided batch of new tokens to process</param>
+ /// <param name="n_past">the number of tokens to use from previous eval calls</param>
+ /// <param name="n_threads"></param>
+ /// <returns>Returns true on success</returns>
+ public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+ {
+ using var pin = tokens.Pin();
+ unsafe
+ {
+ return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+ }
+ }
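+
+ // Usage sketch (illustrative only, assuming `prompt` is an int[] of token ids):
+ //   if (!ctx.Eval(prompt.AsMemory(), n_past: 0, n_threads: 4))
+ //       throw new RuntimeError("Failed to eval.");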
}
}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 79714fea2..dbb1b0707 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -13,17 +13,17 @@ public class SafeLlamaModelHandle
///
/// Total number of tokens in vocabulary of this model
///
- public int VocabCount { get; set; }
+ public int VocabCount { get; }
///
/// Total number of tokens in the context
///
- public int ContextSize { get; set; }
+ public int ContextSize { get; }
///
/// Dimension of embedding vectors
///
- public int EmbeddingCount { get; set; }
+ public int EmbeddingCount { get; }
internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 391a5cc14..de363a3ed 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -2,10 +2,8 @@
using LLama.Native;
using System;
using System.Collections.Generic;
-using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
-using LLama.Exceptions;
using LLama.Extensions;
namespace LLama
@@ -27,41 +25,36 @@ public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParam
}
}
+ [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")]
public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)
{
- var cnt = encoding.GetByteCount(text);
- llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)];
- int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos);
- if (n < 0)
- {
- throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
- "specify the encoding.");
- }
- return res.Take(n);
+ return ctx.Tokenize(text, add_bos, encoding);
}
- public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
+ [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")]
+ public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length)
{
- var logits = NativeApi.llama_get_logits(ctx);
- return new Span(logits, length);
+ if (length != ctx.VocabCount)
+ throw new ArgumentException("length must be the VocabSize");
+
+ return ctx.GetLogits();
}
- public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
+ [Obsolete("Use SafeLLamaContextHandle Eval method instead")]
+ public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads)
{
- int result;
- fixed(llama_token* p = tokens)
- {
- result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads);
- }
- return result;
+ var slice = tokens.AsMemory().Slice(startIndex, n_tokens);
+ return ctx.Eval(slice, n_past, n_threads) ? 0 : 1;
}
+ [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]
public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding)
{
- return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding);
+ return ctx.TokenToString(token, encoding);
}
- public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
+ [Obsolete("No longer used internally by LlamaSharp")]
+ public static string PtrToString(IntPtr ptr, Encoding encoding)
{
#if NET6_0_OR_GREATER
if(encoding == Encoding.UTF8)
@@ -77,21 +70,24 @@ public static unsafe string PtrToString(IntPtr ptr, Encoding encoding)
return Marshal.PtrToStringAuto(ptr);
}
#else
- byte* tp = (byte*)ptr.ToPointer();
- List<byte> bytes = new();
- while (true)
+ unsafe
{
- byte c = *tp++;
- if (c == '\0')
- {
- break;
- }
- else
+ byte* tp = (byte*)ptr.ToPointer();
+ List<byte> bytes = new();
+ while (true)
{
- bytes.Add(c);
+ byte c = *tp++;
+ if (c == '\0')
+ {
+ break;
+ }
+ else
+ {
+ bytes.Add(c);
+ }
}
+ return encoding.GetString(bytes.ToArray());
}
- return encoding.GetString(bytes.ToArray());
#endif
}
}