some prediction refactoring, kv_shift from llama.cpp example
guinmoon committed May 17, 2024
1 parent 24cd85d commit 2eed38b
Showing 2 changed files with 50 additions and 41 deletions.
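
The refactor extracts the system-prompt and image evaluation steps out of predict() into _eval_system_prompt() and _eval_img(), and routes both context-overflow sites through a new overridable kv_shift(). A rough sketch of the resulting predict() flow, condensed from the LLMBase.swift diff below (the names are the ones used there; error handling and sampling are omitted):

    // Condensed control flow of LLMBase.predict() after this commit (sketch, not verbatim).
    try _eval_system_prompt(system_prompt: system_prompt)   // optional system prompt
    try _eval_img(img_path: img_path)                       // optional image for multimodal models
    // Prompt ingestion: for each input batch
    //   if nPast + inputBatch.count >= contextParams.context { try kv_shift(); callback("**C_LIMIT**", 0) }
    //   llm_eval(inputBatch)
    // Generation loop: for each sampled outputToken
    //   if nPast >= contextParams.context - 2 { try kv_shift(); callback("**C_LIMIT**", 0) }
    //   llm_eval([outputToken])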
44 changes: 27 additions & 17 deletions Sources/llmfarm_core/LLMBase.swift
@@ -50,6 +50,7 @@ public class LLMBase {
public var evalCallback: ((Int) -> (Bool))? = nil
public var evalDebugCallback: ((String) -> (Bool))? = nil
public var modelPath: String
public var outputRepeatTokens: [ModelToken] = []

// Used to keep old context until it needs to be rotated or purge out for new tokens
var past: [[ModelToken]] = [] // Will house both queries and responses in order
@@ -329,9 +330,7 @@ public class LLMBase {
}



public func predict(_ input: String, _ callback: ((String, Double) -> Bool),system_prompt:String? = nil,img_path: String? = nil ) throws -> String {
let params = sampleParams
public func _eval_system_prompt(system_prompt:String? = nil) throws{
if system_prompt != nil{
var system_pormpt_Tokens = tokenizePrompt(system_prompt ?? "", .None)
var eval_res:Bool? = nil
@@ -343,6 +342,9 @@
}
self.nPast += Int32(system_pormpt_Tokens.count)
}
}

public func _eval_img(img_path:String? = nil) throws{
if img_path != nil{
do {
try ExceptionCather.catchException {
@@ -356,6 +358,22 @@
throw error
}
}
}

public func kv_shift() throws{
self.nPast = self.nPast / 2
try ExceptionCather.catchException {
_ = try? self.llm_eval(inputBatch: [self.llm_token_eos()])
}
print("Context Limit!")
}

public func predict(_ input: String, _ callback: ((String, Double) -> Bool),system_prompt:String? = nil,img_path: String? = nil ) throws -> String {
let params = sampleParams

try _eval_system_prompt(system_prompt:system_prompt)
try _eval_img(img_path:img_path)

let contextLength = Int32(contextParams.context)
print("Past token count: \(nPast)/\(contextLength) (\(past.count))")
// Tokenize with prompt format
@@ -385,11 +403,8 @@ public class LLMBase {
inputTokens.removeFirst(evalCount)

if self.nPast + Int32(inputBatch.count) >= self.contextParams.context{
self.nPast = 0
try ExceptionCather.catchException {
_ = try? self.llm_eval(inputBatch: [self.llm_token_eos()])
}
// throw ModelError.contextLimit
try self.kv_shift()
callback("**C_LIMIT**",0)
}
var eval_res:Bool? = nil
try ExceptionCather.catchException {
@@ -401,7 +416,7 @@
nPast += Int32(evalCount)
}
// Output
var outputRepeatTokens: [ModelToken] = []
outputRepeatTokens = []
var outputTokens: [ModelToken] = []
var output = [String]()
// Loop until target count is reached
@@ -467,14 +482,9 @@ public class LLMBase {
if completion_loop {
// Send generated token back into model for next generation
var eval_res:Bool? = nil
if self.nPast >= self.contextParams.context - 4{
self.nPast = self.nPast / 2
outputToken = self.llm_token_eos()
try ExceptionCather.catchException {
_ = try? self.llm_eval(inputBatch: [outputToken])
}
print("Context Limit!")
// throw ModelError.contextLimit
if self.nPast >= self.contextParams.context - 2{
try self.kv_shift()
callback("**C_LIMIT**",0)
}
try ExceptionCather.catchException {
eval_res = try? self.llm_eval(inputBatch: [outputToken])
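With this change the context-limit condition is no longer surfaced as a thrown ModelError.contextLimit; the token callback instead receives the marker string **C_LIMIT** and generation continues on the shifted cache. A hypothetical caller could watch for that marker as sketched here (model, prompt and appendToUI are placeholder names, and the assumption that returning true stops generation early is not confirmed by this diff):

    // Illustration only, not part of the commit: caller-side handling of the **C_LIMIT** marker.
    let answer = try model.predict(prompt, { token, tokenTime in
        if token == "**C_LIMIT**" {
            print("context window was shifted; oldest tokens were dropped")
            return false
        }
        appendToUI(token)   // placeholder for whatever the host app does with each token
        return false        // assumed: returning true would ask predict() to stop early
    })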
47 changes: 23 additions & 24 deletions Sources/llmfarm_core/LLaMa.swift
@@ -135,12 +135,14 @@ public class LLaMa: LLMBase {
if self.context == nil {
return false
}
// var tokens_tmp: [llama_token] = [Int32](repeating: 0, count: 100000)
// var tokens_count:Int = 0
// llama_load_session_file(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state.bin",tokens_tmp.mutPtr, 100000,&tokens_count)
// self.session_tokens.append(contentsOf: tokens_tmp[0..<tokens_count])
// try? llm_eval(inputBatch:self.session_tokens)
// llama_load_state(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state_.bin")

// var tokens_tmp: [llama_token] = [Int32](repeating: 0, count: 4096)
// var tokens_count:Int = 0
// llama_state_load_file(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state.bin",tokens_tmp.mutPtr, 4096,&tokens_count)
// self.outputRepeatTokens.append(contentsOf: tokens_tmp[0..<tokens_count-1])
// self.nPast = tokens_tmp[tokens_count-1]


if !load_clip_model(){
return false
}
@@ -175,7 +177,8 @@

deinit {
// llama_save_state(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state_.bin")
// llama_save_session_file(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state.bin",self.session_tokens, self.session_tokens.count)
// self.outputRepeatTokens.append(self.nPast)
// llama_state_save_file(self.context,"/Users/guinmoon/Library/Containers/com.guinmoon.LLMFarm/Data/Documents/models/dump_state.bin",self.outputRepeatTokens, self.outputRepeatTokens.count)
print("deinit LLaMa")
self.destroy_objects()
print("LLaMa deinited")
@@ -223,26 +226,23 @@ public class LLaMa: LLMBase {
}
return true
}

public override func kv_shift() throws{
let n_discard = self.nPast/2
llama_kv_cache_seq_rm (context, 0, 0 , n_discard);
llama_kv_cache_seq_add(context, 0, n_discard, self.nPast, -n_discard);
self.nPast -= n_discard;
try ExceptionCather.catchException {
_ = try? self.llm_eval(inputBatch: [self.llm_token_eos()])
}
self.nPast+=1
self.outputRepeatTokens = []
print("Context Limit!")
}

func completion_init(tokens_list: [ModelToken]) {
// print("attempting to complete \"\(text)\"")

// tokens_list = tokenize(text: text, add_bos: true)
temporary_invalid_cchars = []

// let n_ctx = llama_n_ctx(context)
// let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
//
// print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")
//
// if n_kv_req > n_ctx {
// print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
// }

// for id in tokens_list {
// print(String(cString: token_to_piece(token: id) + [0]))
// }

llama_batch_clear(&batch!)

for i1 in 0..<tokens_list.count {
Expand All @@ -255,7 +255,6 @@ public class LLaMa: LLMBase {
print("llama_decode() failed")
}

// n_cur = batch.n_tokens
}


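The LLaMa override follows the context-shift approach from llama.cpp's main example: discard the oldest half of the cached positions, then slide the remaining entries back so that positions stay contiguous. With nPast = 2048, for example, n_discard = 1024; positions [0, 1024) of sequence 0 are removed, positions [1024, 2048) are shifted back by 1024, nPast becomes 1024, and one EOS token is evaluated (nPast then 1025) so decoding can continue without re-evaluating the whole prompt. A commented restatement of the override, using the same calls as the diff above:

    // llama.cpp-style context shift, as in LLaMa.kv_shift() above (comments added for illustration).
    let n_discard = self.nPast / 2                                          // drop the oldest half of the cache
    llama_kv_cache_seq_rm (context, 0, 0, n_discard)                        // remove positions [0, n_discard) of seq 0
    llama_kv_cache_seq_add(context, 0, n_discard, self.nPast, -n_discard)   // shift [n_discard, nPast) back by n_discard
    self.nPast -= n_discard
    _ = try? self.llm_eval(inputBatch: [self.llm_token_eos()])              // re-prime the shifted cache with an EOS token
    self.nPast += 1
    self.outputRepeatTokens = []                                            // reset the repeat-penalty history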
