Commit

Merge branch 'SciSharp:master' into master
SignalRT authored Aug 8, 2023
2 parents a8f2538 + f612275 commit 115cdb7
Showing 10 changed files with 198 additions and 66 deletions.
5 changes: 4 additions & 1 deletion LLama/Extensions/IModelParamsExtensions.cs
@@ -6,6 +6,9 @@

namespace LLama.Extensions
{
+ /// <summary>
+ /// Extension methods to the IModelParams interface
+ /// </summary>
public static class IModelParamsExtensions
{
/// <summary>
@@ -31,7 +34,7 @@ public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result)
result.n_gpu_layers = @params.GpuLayerCount;
result.seed = @params.Seed;
result.f16_kv = @params.UseFp16Memory;
- result.use_mmap = @params.UseMemoryLock;
+ result.use_mmap = @params.UseMemorymap;
result.use_mlock = @params.UseMemoryLock;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
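
A hedged caller-side sketch of what this one-line fix changes, assuming the library's ModelParams implementation of IModelParams and an illustrative model path (neither appears in this hunk): before the fix, use_mmap was copied from UseMemoryLock, so the two options below could not be set independently.

// Hypothetical usage, not part of the commit.
var @params = new ModelParams("models/model.bin")
{
    UseMemorymap = true,   // now mapped to result.use_mmap
    UseMemoryLock = false  // mapped to result.use_mlock, as before
};
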
2 changes: 1 addition & 1 deletion LLama/LLamaEmbedder.cs
@@ -55,7 +55,7 @@ public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos =
text = text.Insert(0, " ");
}

- var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray();
+ var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));

// TODO(Rinne): deal with log of prompt

10 changes: 4 additions & 6 deletions LLama/LLamaInstructExecutor.cs
@@ -30,8 +30,8 @@ public class InstructExecutor : StatefulExecutorBase
public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n",
string instructionSuffix = "\n\n### Response:\n\n") : base(model)
{
- _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray();
- _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray();
+ _inp_pfx = _model.Tokenize(instructionPrefix, true);
+ _inp_sfx = _model.Tokenize(instructionSuffix, false);
_instructionPrefix = instructionPrefix;
}

@@ -133,7 +133,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)

_embed_inps.AddRange(_inp_sfx);

- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}
/// <inheritdoc />
@@ -146,9 +146,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
{
string last_output = "";
foreach (var id in _last_n_tokens)
- {
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
- }
+ last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);

foreach (var antiprompt in args.Antiprompts)
{
6 changes: 3 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -25,7 +25,7 @@ public class InteractiveExecutor : StatefulExecutorBase
/// <param name="model"></param>
public InteractiveExecutor(LLamaModel model) : base(model)
{
- _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray();
+ _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding);
}

/// <inheritdoc />
@@ -114,7 +114,7 @@ protected override void PreprocessInputs(string text, InferStateArgs args)
}
var line_inp = _model.Tokenize(text, false);
_embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Count();
+ args.RemainedTokens -= line_inp.Length;
}
}

@@ -133,7 +133,7 @@ protected override bool PostProcess(IInferenceParams inferenceParams, InferState
string last_output = "";
foreach (var id in _last_n_tokens)
{
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding);
+ last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);
}

foreach (var antiprompt in args.Antiprompts)
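
Both executors now build the recent output with TokenToString rather than Utils.PtrToString; a condensed sketch of the shared antiprompt check follows (field names come from the diff; the EndsWith comparison is an assumption, since these hunks end before the body of the loop).

// Condensed sketch; _last_n_tokens, _model and args belong to the executor.
string last_output = "";
foreach (var id in _last_n_tokens)
    last_output += _model.NativeHandle.TokenToString(id, _model.Encoding);

foreach (var antiprompt in args.Antiprompts)
{
    // Assumed check: stop generating when the output tail matches an antiprompt.
    if (last_output.EndsWith(antiprompt))
        return true;
}
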
19 changes: 7 additions & 12 deletions LLama/LLamaModel.cs
@@ -64,10 +64,9 @@ public LLamaModel(IModelParams Params, string encoding = "UTF-8", ILLamaLogger?
/// <param name="text"></param>
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <returns></returns>
- public IEnumerable<llama_token> Tokenize(string text, bool addBos = true)
+ public llama_token[] Tokenize(string text, bool addBos = true)
{
- // TODO: reconsider whether to convert to array here.
- return Utils.Tokenize(_ctx, text, addBos, _encoding);
+ return _ctx.Tokenize(text, addBos, _encoding);
}

/// <summary>
@@ -79,9 +78,7 @@ public string DeTokenize(IEnumerable<llama_token> tokens)
{
StringBuilder sb = new();
foreach(var token in tokens)
- {
- sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding));
- }
+ sb.Append(_ctx.TokenToString(token, _encoding));
return sb.ToString();
}

@@ -285,8 +282,8 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
- var n_vocab = NativeApi.llama_n_vocab(_ctx);
- var logits = Utils.GetLogits(_ctx, n_vocab);
+ var n_vocab = _ctx.VocabCount;
+ var logits = _ctx.GetLogits();

// Apply params.logit_bias map
if(logitBias is not null)
@@ -338,7 +335,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
n_eval = Params.BatchSize;
}

- if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0)
+ if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads))
{
_logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error);
throw new RuntimeError("Failed to eval.");
@@ -353,9 +350,7 @@ public llama_token Eval(llama_token[] tokens, llama_token pastTokensCount)
internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids)
{
foreach(var id in ids)
- {
- yield return Utils.TokenToString(id, _ctx, _encoding);
- }
+ yield return _ctx.TokenToString(id, _encoding);
}

/// <inheritdoc />
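
A caller-side sketch of the reshaped LLamaModel API (the model instance and prompt are assumptions; llama_token is an int alias in this codebase):

// Hypothetical caller, not part of the commit. Tokenize now returns
// llama_token[] directly: no defensive ToArray(), and Length replaces Count().
var ids = model.Tokenize("Hello world", addBos: true);
Console.WriteLine($"prompt is {ids.Length} tokens");

// DeTokenize round-trips the ids back into text.
Console.WriteLine(model.DeTokenize(ids));

// Eval consumes the tokens in Params.BatchSize chunks via the new
// SafeLLamaContextHandle.Eval; the return value is assumed to be the
// updated past-token count.
var past = model.Eval(ids, 0);
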
8 changes: 4 additions & 4 deletions LLama/LLamaStatelessExecutor.cs
@@ -31,7 +31,7 @@ public StatelessExecutor(LLamaModel model)
_model = model;

var tokens = model.Tokenize(" ", true).ToArray();
- Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads);
+ _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads);
_originalState = model.GetState();
}

@@ -52,7 +52,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
List<llama_token> tokens = _model.Tokenize(text, true).ToList();
int n_prompt_tokens = tokens.Count;

- Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads);
+ _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads);

lastTokens.AddRange(tokens);
n_past += n_prompt_tokens;
@@ -76,7 +76,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams

lastTokens.Add(id);

- string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding);
+ string response = _model.NativeHandle.TokenToString(id, _model.Encoding);
yield return response;

tokens.Clear();
@@ -87,7 +87,7 @@ public IEnumerable<string> Infer(string text, IInferenceParams? inferenceParams
string last_output = "";
foreach (var token in lastTokens)
{
- last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding);
+ last_output += _model.NativeHandle.TokenToString(token, _model.Encoding);
}

bool should_break = false;
16 changes: 14 additions & 2 deletions LLama/Native/NativeApi.cs
@@ -207,6 +207,17 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads);

+ /// <summary>
+ /// Run the llama inference to obtain the logits and probabilities for the next token.
+ /// tokens + n_tokens is the provided batch of new tokens to process
+ /// n_past is the number of tokens to use from previous eval calls
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="tokens"></param>
+ /// <param name="n_tokens"></param>
+ /// <param name="n_past"></param>
+ /// <param name="n_threads"></param>
+ /// <returns>Returns 0 on success</returns>
[DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads);

@@ -218,6 +229,7 @@ public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src)
/// </summary>
/// <param name="ctx"></param>
/// <param name="text"></param>
+ /// <param name="encoding"></param>
/// <param name="tokens"></param>
/// <param name="n_max_tokens"></param>
/// <param name="add_bos"></param>
@@ -256,8 +268,8 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// The logits for the last token are stored in the last row
- /// Can be mutated in order to change the probabilities of the next token
- /// Rows: n_tokens
+ /// Can be mutated in order to change the probabilities of the next token.<br />
+ /// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <param name="ctx"></param>
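
Both imports bind the same native symbol, the second via EntryPoint = "llama_eval"; the pointer overload lets managed callers pass pinned memory without an extra marshalling copy. A hedged sketch of calling it directly (ctx, tokens, n_past and n_threads are assumed to be in scope and valid; llama_token is an int alias):

// Pin the managed buffer for the duration of the native call.
unsafe
{
    fixed (int* ptr = tokens)
    {
        if (NativeApi.llama_eval_with_pointer(ctx, ptr, tokens.Length, n_past, n_threads) != 0)
            throw new RuntimeError("Failed to eval.");
    }
}
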
128 changes: 128 additions & 0 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -1,4 +1,6 @@
using System;
+ using System.Buffers;
+ using System.Text;
using LLama.Exceptions;

namespace LLama.Native
@@ -9,11 +11,29 @@ namespace LLama.Native
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
+ #region properties and fields
+ /// <summary>
+ /// Total number of tokens in vocabulary of this model
+ /// </summary>
+ public int VocabCount => ThrowIfDisposed().VocabCount;
+
+ /// <summary>
+ /// Total number of tokens in the context
+ /// </summary>
+ public int ContextSize => ThrowIfDisposed().ContextSize;
+
+ /// <summary>
+ /// Dimension of embedding vectors
+ /// </summary>
+ public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount;
+
/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;
+ #endregion

+ #region construction/destruction
/// <summary>
/// Create a new SafeLLamaContextHandle
/// </summary>
@@ -42,6 +62,16 @@ protected override bool ReleaseHandle()
return true;
}

+ private SafeLlamaModelHandle ThrowIfDisposed()
+ {
+ if (IsClosed)
+ throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed");
+ if (_model == null || _model.IsClosed)
+ throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed");
+
+ return _model;
+ }
+
/// <summary>
/// Create a new llama_state for the given model
/// </summary>
@@ -57,5 +87,103 @@ public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaCon

return new(ctx_ptr, model);
}
+ #endregion
+
+ /// <summary>
+ /// Convert the given text into tokens
+ /// </summary>
+ /// <param name="text">The text to tokenize</param>
+ /// <param name="add_bos">Whether the "BOS" token should be added</param>
+ /// <param name="encoding">Encoding to use for the text</param>
+ /// <returns></returns>
+ /// <exception cref="RuntimeError"></exception>
+ public int[] Tokenize(string text, bool add_bos, Encoding encoding)
+ {
+ ThrowIfDisposed();
+
+ // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't
+ // possibly be more than this.
+ var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0);
+
+ // "Rent" an array to write results into (avoiding an allocation of a large array)
+ var temporaryArray = ArrayPool<int>.Shared.Rent(count);
+ try
+ {
+ // Do the actual conversion
+ var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos);
+ if (n < 0)
+ {
+ throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " +
+ "specify the encoding.");
+ }
+
+ // Copy the results from the rented into an array which is exactly the right size
+ var result = new int[n];
+ Array.ConstrainedCopy(temporaryArray, 0, result, 0, n);
+
+ return result;
+ }
+ finally
+ {
+ ArrayPool<int>.Shared.Return(temporaryArray);
+ }
+ }
+
+ /// <summary>
+ /// Token logits obtained from the last call to llama_eval()
+ /// The logits for the last token are stored in the last row
+ /// Can be mutated in order to change the probabilities of the next token.<br />
+ /// Rows: n_tokens<br />
+ /// Cols: n_vocab
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <returns></returns>
+ public Span<float> GetLogits()
+ {
+ var model = ThrowIfDisposed();
+
+ unsafe
+ {
+ var logits = NativeApi.llama_get_logits(this);
+ return new Span<float>(logits, model.VocabCount);
+ }
+ }
+
+ /// <summary>
+ /// Convert a token into a string
+ /// </summary>
+ /// <param name="token"></param>
+ /// <param name="encoding"></param>
+ /// <returns></returns>
+ public string TokenToString(int token, Encoding encoding)
+ {
+ return ThrowIfDisposed().TokenToString(token, encoding);
+ }
+
+ /// <summary>
+ /// Convert a token into a span of bytes that could be decoded into a string
+ /// </summary>
+ /// <param name="token"></param>
+ /// <returns></returns>
+ public ReadOnlySpan<byte> TokenToSpan(int token)
+ {
+ return ThrowIfDisposed().TokenToSpan(token);
+ }
+
+ /// <summary>
+ /// Run the llama inference to obtain the logits and probabilities for the next token.
+ /// </summary>
+ /// <param name="tokens">The provided batch of new tokens to process</param>
+ /// <param name="n_past">the number of tokens to use from previous eval calls</param>
+ /// <param name="n_threads"></param>
+ /// <returns>Returns true on success</returns>
+ public bool Eval(Memory<int> tokens, int n_past, int n_threads)
+ {
+ using var pin = tokens.Pin();
+ unsafe
+ {
+ return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0;
+ }
+ }
}
}
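
Taken together, the new surface supports a minimal evaluate-and-pick loop. A hedged end-to-end sketch follows; obtaining the context handle, the thread count and the greedy argmax are assumptions, while Tokenize, Eval, GetLogits and TokenToString are the members added above.

// Evaluate a prompt and return the single most likely next token as text.
static string GreedyNextToken(SafeLLamaContextHandle ctx, string prompt)
{
    // Tokenize (with BOS), then evaluate the whole batch from position 0.
    int[] tokens = ctx.Tokenize(prompt, add_bos: true, Encoding.UTF8);
    if (!ctx.Eval(tokens.AsMemory(), n_past: 0, n_threads: 4))
        throw new RuntimeError("Failed to eval.");

    // One n_vocab-wide row of logits for the last evaluated token.
    Span<float> logits = ctx.GetLogits();
    var best = 0;
    for (var i = 1; i < logits.Length; i++)
        if (logits[i] > logits[best])
            best = i;

    // Naive greedy pick, purely to exercise the API; real callers sample.
    return ctx.TokenToString(best, Encoding.UTF8);
}
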
6 changes: 3 additions & 3 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -13,17 +13,17 @@ public class SafeLlamaModelHandle
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
- public int VocabCount { get; set; }
+ public int VocabCount { get; }

/// <summary>
/// Total number of tokens in the context
/// </summary>
- public int ContextSize { get; set; }
+ public int ContextSize { get; }

/// <summary>
/// Dimension of embedding vectors
/// </summary>
- public int EmbeddingCount { get; set; }
+ public int EmbeddingCount { get; }

internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)