restore model load duration on generate response (ollama#1524)
* restore model load duration on generate response

- set model load duration on the generate and chat done responses
- calculate createdAt time when the response is created

* remove checkpoints from predict opts

* Update routes.go
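
As a rough illustration of the change described above, here is a minimal, self-contained Go sketch (not the actual server code): the handler records a checkpoint before and after the model load, then stamps CreatedAt and the two durations onto the final done response itself, rather than threading the checkpoints through PredictOpts into the runner. Response, loadModel, and generate are hypothetical stand-ins for the real server types and functions.

package main

import (
	"fmt"
	"time"
)

// Response mirrors only the fields this commit touches; it is a
// hypothetical stand-in, not the real api.GenerateResponse.
type Response struct {
	CreatedAt     time.Time
	Done          bool
	TotalDuration time.Duration
	LoadDuration  time.Duration
}

func loadModel() { time.Sleep(50 * time.Millisecond) } // stand-in for loading the model
func generate()  { time.Sleep(10 * time.Millisecond) } // stand-in for token generation

func main() {
	checkpointStart := time.Now() // taken before the model is loaded
	loadModel()
	checkpointLoaded := time.Now() // taken once the model is ready

	generate()

	// CreatedAt is stamped when the response is built, and both durations
	// are computed in the handler on the final (Done) response, so the
	// checkpoints never need to travel through the predict options.
	resp := Response{
		CreatedAt:     time.Now().UTC(),
		Done:          true,
		TotalDuration: time.Since(checkpointStart),
		LoadDuration:  checkpointLoaded.Sub(checkpointStart),
	}
	fmt.Printf("load=%v total=%v\n", resp.LoadDuration, resp.TotalDuration)
}

Computing the durations at response-construction time keeps PredictOpts a pure description of the request, which is what the llm/llama.go half of the diff restores.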
BruceMacD committed Dec 14, 2023
1 parent 31f0551 commit 6ee8c80
Showing 2 changed files with 27 additions and 36 deletions.
llm/llama.go (17 changes: 4 additions & 13 deletions)

@@ -548,17 +548,12 @@ const maxBufferSize = 512 * format.KiloByte
 const maxRetries = 6
 
 type PredictOpts struct {
-	Prompt           string
-	Format           string
-	Images           []api.ImageData
-	CheckpointStart  time.Time
-	CheckpointLoaded time.Time
+	Prompt string
+	Format string
+	Images []api.ImageData
 }
 
 type PredictResult struct {
-	CreatedAt          time.Time
-	TotalDuration      time.Duration
-	LoadDuration       time.Duration
 	Content            string
 	Done               bool
 	PromptEvalCount    int

@@ -681,16 +676,12 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 
 	if p.Content != "" {
 		fn(PredictResult{
-			CreatedAt: time.Now().UTC(),
-			Content:   p.Content,
+			Content: p.Content,
 		})
 	}
 
 	if p.Stop {
 		fn(PredictResult{
-			CreatedAt:     time.Now().UTC(),
-			TotalDuration: time.Since(predict.CheckpointStart),
-
 			Done:               true,
 			PromptEvalCount:    p.Timings.PromptN,
 			PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
server/routes.go (46 changes: 23 additions & 23 deletions)

@@ -261,38 +261,39 @@ func GenerateHandler(c *gin.Context) {
 
 		resp := api.GenerateResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Response:  r.Content,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
 				EvalDuration:       r.EvalDuration,
 			},
 		}
 
-		if r.Done && !req.Raw {
-			embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
-			if err != nil {
-				ch <- gin.H{"error": err.Error()}
-				return
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
+			if !req.Raw {
+				embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+				resp.Context = embd
 			}
-			resp.Context = embd
 		}
 
 		ch <- resp
 	}
 
 	// Start prediction
 	predictReq := llm.PredictOpts{
-		Prompt:           prompt,
-		Format:           req.Format,
-		CheckpointStart:  checkpointStart,
-		CheckpointLoaded: checkpointLoaded,
-		Images:           req.Images,
+		Prompt: prompt,
+		Format: req.Format,
+		Images: req.Images,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}

@@ -1012,19 +1013,20 @@ func ChatHandler(c *gin.Context) {
 
 		resp := api.ChatResponse{
 			Model:     req.Model,
-			CreatedAt: r.CreatedAt,
+			CreatedAt: time.Now().UTC(),
 			Done:      r.Done,
 			Metrics: api.Metrics{
-				TotalDuration:      r.TotalDuration,
-				LoadDuration:       r.LoadDuration,
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
 				EvalCount:          r.EvalCount,
 				EvalDuration:       r.EvalDuration,
 			},
 		}
 
-		if !r.Done {
+		if r.Done {
+			resp.TotalDuration = time.Since(checkpointStart)
+			resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+		} else {
 			resp.Message = &api.Message{Role: "assistant", Content: r.Content}
 		}
 

@@ -1033,11 +1035,9 @@
 
 	// Start prediction
 	predictReq := llm.PredictOpts{
-		Prompt:           prompt,
-		Format:           req.Format,
-		CheckpointStart:  checkpointStart,
-		CheckpointLoaded: checkpointLoaded,
-		Images:           images,
+		Prompt: prompt,
+		Format: req.Format,
+		Images: images,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 		ch <- gin.H{"error": err.Error()}
