Commit
make rms_norm_eps a parameter (ggerganov#2374)
* make rms_norm_eps a parameter

* add rms_norm_eps to command line

* fix baby llama, test-grad0

* use scientific notation for eps param in the help

ggml-ci
slaren committed Jul 24, 2023
1 parent b3f138d commit 41c6741
Showing 11 changed files with 89 additions and 56 deletions.
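The change in one picture: ggml_rms_norm() and ggml_rms_norm_inplace() now take the epsilon explicitly instead of hard-coding 1e-6f, the value is stashed in the tensor's op_params so every backend shown here (CPU, CUDA, Metal) reads the same number at compute time, and a new -eps / --rms-norm-eps flag exposes it on the command line. A minimal call-site sketch in C++ (illustrative, assembled from the diff rather than copied verbatim):

    // before: epsilon was baked into the op as 1e-6f
    // cur = ggml_rms_norm(ctx0, inpL);

    // after: the caller supplies it; 1e-6f keeps the old behavior,
    // while LLaMA v2 models want 1e-5f (see the help text below)
    static const float rms_norm_eps = 1e-6f;
    cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);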
20 changes: 11 additions & 9 deletions examples/baby-llama/baby-llama.cpp
@@ -8,6 +8,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

+ static const float rms_norm_eps = 1e-6f;

float frand() {
return (float)rand()/(float)RAND_MAX;
}
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
{

// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
{

// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
{

// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
8 changes: 8 additions & 0 deletions examples/common.cpp
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gqa = std::stoi(argv[i]);
+ } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
+ lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
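Taken together, the common.cpp/common.h changes wire the flag end to end. Roughly (a sketch using only the helpers touched by this diff; error handling and the rest of the setup omitted):

    gpt_params params;                     // params.rms_norm_eps defaults to 1e-6f
    gpt_params_parse(argc, argv, params);  // "-eps 1e-5" -> params.rms_norm_eps = 1e-5f
    llama_context_params lparams = llama_context_params_from_gpt_params(params);
    // lparams.rms_norm_eps now carries the value into llama context creation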
23 changes: 12 additions & 11 deletions examples/common.h
@@ -22,18 +22,19 @@
int32_t get_num_physical_cores();

struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ float rms_norm_eps = 1e-6; // rms norm epsilon
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

32 changes: 17 additions & 15 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

+ static const float rms_norm_eps = 1e-6f;

struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{

// inpL shape [n_embd,N,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{

// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{

// inpL shape [n_embd,N*n_batch,1,1]
- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(

// norm
{
- cur = ggml_rms_norm(ctx0, inpL);
+ cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{
// norm
{
- cur = ggml_rms_norm(ctx0, inpFF);
+ cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm
{

- inpL = ggml_rms_norm(ctx0, inpL);
+ inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
- use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+ use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
- use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+ use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
}
clr_buf(0);
use_buf(0);
- struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+ struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1);
13 changes: 7 additions & 6 deletions ggml-cuda.cu
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}

- static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;

- const float eps = 1e-6f;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}

- static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}

static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;

+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

// compute
- rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);

(void) src1;
(void) dst;
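The memcpy above is the consumer half of the ggml_set_op_params() call added in ggml.c below: at graph-build time the epsilon is copied as raw bytes into the tensor's op_params, and at compute time each backend copies it back out. Using memcpy rather than a pointer cast sidesteps strict-aliasing and alignment pitfalls when reinterpreting those bytes as a float. Schematically (the same pattern the CPU and Metal paths use; not the full ggml internals):

    // producer, when the graph is built (ggml.c):
    ggml_set_op_params(result, &eps, sizeof(eps));

    // consumer, when the op is executed (CPU / CUDA / Metal):
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));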
3 changes: 2 additions & 1 deletion ggml-metal.m
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}

- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 512;

16 changes: 10 additions & 6 deletions ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ float eps,
bool inplace) {
bool is_node = false;

@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

- // TODO: maybe store epsilon here?
+ ggml_set_op_params(result, &eps, sizeof(eps));

result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@

struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, false);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, false);
}

struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, true);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, true);
}

struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(

GGML_TENSOR_UNARY_OP_LOCALS;

- const float eps = 1e-6f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
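For context on what eps does in the math: RMS norm rescales each row x by 1 / sqrt(mean(x^2) + eps), so the epsilon only keeps the square root away from zero for near-silent rows, and 1e-6 vs 1e-5 matters only when mean(x^2) is of comparable magnitude. A self-contained per-row sketch (simplified from this function, which also walks the tensor's higher dimensions and strides):

    #include <math.h>
    #include <stddef.h>

    // one row of RMS norm: y = x / sqrt(mean(x^2) + eps)
    static void rms_norm_row(const float * x, float * y, size_t n, float eps) {
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            sum += x[i] * x[i];                       // sum of squares
        }
        const float mean  = sum / (float) n;          // mean square
        const float scale = 1.0f / sqrtf(mean + eps); // eps guards the sqrt
        for (size_t i = 0; i < n; ++i) {
            y[i] = x[i] * scale;
        }
    }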
