Added 'maximize' option to Adadelta.
NiklasGustafsson committed Nov 30, 2022
1 parent 03ceea3 commit 6401fb7
Showing 3 changed files with 49 additions and 23 deletions.
49 changes: 29 additions & 20 deletions src/TorchSharp/Optimizers/Adadelta.cs
@@ -25,10 +25,10 @@ public static partial class optim
/// <param name="rho">Coefficient used for computing a running average of squared gradients (default: 0.9)</param>
/// <param name="eps">Term added to the denominator to improve numerical stability, i.e. avoid division-by-zero (default: 1e-6)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <returns></returns>
public static Adadelta Adadelta(IEnumerable<Parameter> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
public static Adadelta Adadelta(IEnumerable<Parameter> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
{
return new Adadelta(parameters, lr, rho, eps, weight_decay);
return new Adadelta(parameters, lr, rho, eps, weight_decay, maximize);
}

/// <summary>
@@ -42,10 +42,10 @@ public static Adadelta Adadelta(IEnumerable<Parameter> parameters, double lr = 1
/// <param name="rho">Coefficient used for computing a running average of squared gradients (default: 0.9)</param>
/// <param name="eps">Term added to the denominator to improve numerical stability, i.e. avoid division-by-zero (default: 1e-6)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <returns></returns>
public static Adadelta Adadelta(IEnumerable<(string name, Parameter parameter)> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
public static Adadelta Adadelta(IEnumerable<(string name, Parameter parameter)> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
{
return new Adadelta(parameters.Select(np => np.parameter), lr, rho, eps, weight_decay);
return new Adadelta(parameters.Select(np => np.parameter), lr, rho, eps, weight_decay, maximize);
}

/// <summary>
@@ -59,10 +59,10 @@ public static Adadelta Adadelta(IEnumerable<(string name, Parameter parameter)>
/// <param name="rho">Coefficient used for computing a running average of squared gradients (default: 0.9)</param>
/// <param name="eps">Term added to the denominator to improve numerical stability, i.e. avoid division-by-zero (default: 1e-6)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <returns></returns>
public static Adadelta Adadelta(IEnumerable<Adadelta.ParamGroup> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
public static Adadelta Adadelta(IEnumerable<Adadelta.ParamGroup> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
{
return new Adadelta(parameters, lr, rho, eps, weight_decay);
return new Adadelta(parameters, lr, rho, eps, weight_decay, maximize);
}
}
}
@@ -79,8 +79,9 @@ public class Adadelta : OptimizerHelper
/// <param name="rho">Coefficient used for computing a running average of squared gradients (default: 0.9)</param>
/// <param name="eps">Term added to the denominator to improve numerical stability, i.e. avoid division-by-zero (default: 1e-6)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
public Adadelta(IEnumerable<Parameter> parameters, double lr, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
: this(new ParamGroup[] { new ParamGroup { Parameters = parameters } }, lr, rho, eps, weight_decay)
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
public Adadelta(IEnumerable<Parameter> parameters, double lr, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
: this(new ParamGroup[] { new ParamGroup { Parameters = parameters } }, lr, rho, eps, weight_decay, maximize)
{
}

@@ -92,7 +93,8 @@ public Adadelta(IEnumerable<Parameter> parameters, double lr, double rho = 0.9,
/// <param name="rho">Coefficient used for computing a running average of squared gradients (default: 0.9)</param>
/// <param name="eps">Term added to the denominator to improve numerical stability, i.e. avoid division-by-zero (default: 1e-6)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
public Adadelta(IEnumerable<ParamGroup> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
public Adadelta(IEnumerable<ParamGroup> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
{
if (lr < 0.0) throw new ArgumentException($"Invalid learning rate: {lr}");
if (rho < 0.0 || rho > 1.0) throw new ArgumentException($"Invalid rho value: {rho}");
@@ -104,6 +106,7 @@ public Adadelta(IEnumerable<ParamGroup> parameters, double lr = 1.0, double rho
InitialLearningRate = lr,
rho = rho,
eps = eps,
maximize = maximize,
weight_decay = weight_decay
};

@@ -116,23 +119,24 @@ public Adadelta(IEnumerable<ParamGroup> parameters, double lr = 1.0, double rho
}

/// <summary>
/// Performs a single optimization step (parameter update).
/// </summary>
/// <param name="closure">A closure that reevaluates the model and returns the loss. Optional for most optimizers.</param>
/// <returns></returns>
public override Tensor step(Func<Tensor> closure = null)
/// Performs a single optimization step (parameter update).
/// </summary>
/// <param name="closure">A closure that reevaluates the model and returns the loss. Optional for most optimizers.</param>
/// <returns></returns>
public override Tensor step(Func<Tensor> closure = null)
{
return _step<ParamGroup>(group => {
var options = group.Options as Options;
var rho = options.rho.Value;
var eps = options.eps.Value;
var weight_decay = options.weight_decay.Value;
var maximize = options.maximize.Value;
var lr = options.LearningRate.Value;
foreach (var param in group.Parameters) {
var grad = param.grad();
var grad = (maximize) ? -param.grad() : param.grad();
if (grad is null) continue;
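
Note (not part of the commit): the new flag only flips the sign of the gradient before the usual Adadelta recursion runs; the rest of step() is unchanged. A hedged summary of the update in LaTeX, where the epsilon placement follows PyTorch's documented Adadelta rule and is an assumption here, since the remainder of the step body is collapsed in this view:

\[
\begin{aligned}
g_t &= \begin{cases} -\nabla_\theta f(\theta_{t-1}) & \text{if maximize} \\ \phantom{-}\nabla_\theta f(\theta_{t-1}) & \text{otherwise} \end{cases} \\
v_t &= \rho\, v_{t-1} + (1-\rho)\, g_t^2 \\
\Delta_t &= \frac{\sqrt{u_{t-1}+\epsilon}}{\sqrt{v_t+\epsilon}}\, g_t \\
u_t &= \rho\, u_{t-1} + (1-\rho)\, \Delta_t^2 \\
\theta_t &= \theta_{t-1} - \mathrm{lr}\cdot \Delta_t
\end{aligned}
\]

With the negated g_t, the final line ascends the objective instead of descending it.
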
@@ -253,6 +257,7 @@ public override void add_param_group(Modules.ParamGroup param_group)
if (!opt.rho.HasValue) opt.rho = def.rho;
if (!opt.eps.HasValue) opt.eps = def.eps;
if (!opt.weight_decay.HasValue) opt.weight_decay = def.weight_decay;
if (!opt.maximize.HasValue) opt.maximize = def.maximize;

opt.InitialLearningRate = opt.LearningRate.Value;

@@ -272,6 +277,7 @@ public class Options : OptimizerOptions
public double? rho;
public double? eps;
public double? weight_decay;
public bool? maximize;

/// <summary>
/// Load optimizer options (param-group hyperparameters) from another optimizer.
@@ -284,6 +290,7 @@ public override void LoadStateDict(OptimizerOptions source)
rho = opts.rho;
eps = opts.eps;
weight_decay = opts.weight_decay;
maximize = opts.maximize;
}

/// <summary>
@@ -296,6 +303,7 @@ public override void LoadStateDict(BinaryReader reader)
rho = reader.ReadDouble();
eps = reader.ReadDouble();
weight_decay = reader.ReadDouble();
maximize = reader.ReadBoolean();
}

/// <summary>
@@ -308,6 +316,7 @@ public override void SaveStateDict(BinaryWriter writer)
writer.Write(rho.Value);
writer.Write(eps.Value);
writer.Write(weight_decay.Value);
writer.Write(maximize.Value);
}
}

@@ -317,8 +326,8 @@ public ParamGroup() { }

public ParamGroup(IEnumerable<Parameter> parameters, Options options) : base(parameters, options) { }

public ParamGroup(IEnumerable<Parameter> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0)
: base(parameters, new Adadelta.Options { LearningRate = lr, rho = rho, eps = eps, weight_decay = weight_decay })
public ParamGroup(IEnumerable<Parameter> parameters, double lr = 1.0, double rho = 0.9, double eps = 1e-6, double weight_decay = 0, bool maximize = false)
: base(parameters, new Adadelta.Options { LearningRate = lr, rho = rho, eps = eps, weight_decay = weight_decay, maximize = maximize })
{
}
}
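
For context beyond the diff, here is a minimal usage sketch of the new option. The model, data, and loop are illustrative placeholders and are not part of the commit; only the Adadelta call exercises the surface added here, in the same way the new TestTrainingAdadeltaMax test does further down.

    using TorchSharp;
    using static TorchSharp.torch;
    using static TorchSharp.torch.nn;

    // Illustrative two-layer model and random data.
    var model = Sequential(("lin1", Linear(10, 25)), ("relu1", ReLU()), ("lin2", Linear(25, 1)));
    var x = randn(32, 10);
    var y = randn(32, 1);

    // New in this commit: maximize flips the gradient sign inside step(),
    // so the optimizer ascends the objective instead of descending it.
    var optimizer = optim.Adadelta(model.parameters(), lr: 1.0, maximize: true);

    for (int i = 0; i < 10; i++) {
        optimizer.zero_grad();
        var objective = functional.mse_loss(model.forward(x), y);
        objective.backward();
        optimizer.step();
    }

Negating the gradient inside step(), rather than changing the sign of the learning rate, mirrors how PyTorch implements its own maximize flag.
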
6 changes: 3 additions & 3 deletions src/TorchSharp/Optimizers/Adam.cs
@@ -27,7 +27,7 @@ public static partial class optim
/// <param name="eps">Term added to the denominator to improve numerical stability (default: 1e-8)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <param name="amsgrad">Whether to use the AMSGrad variant of this algorithm. (default: False)</param>
/// <param name="maximize"></param>
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
/// <returns></returns>
public static Adam Adam(IEnumerable<Parameter> parameters, double lr = 1e-3, double beta1 = 0.9, double beta2 = 0.99, double eps = 1e-8, double weight_decay = 0, bool amsgrad = false, bool maximize = false)
{
@@ -47,7 +47,7 @@ public static Adam Adam(IEnumerable<Parameter> parameters, double lr = 1e-3, dou
/// <param name="eps">Term added to the denominator to improve numerical stability (default: 1e-8)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <param name="amsgrad">Whether to use the AMSGrad variant of this algorithm. (default: False)</param>
/// <param name="maximize"></param>
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
/// <returns></returns>
public static Adam Adam(IEnumerable<(string name, Parameter parameter)> parameters, double lr = 1e-3, double beta1 = 0.9, double beta2 = 0.99, double eps = 1e-8, double weight_decay = 0, bool amsgrad = false, bool maximize = false)
{
@@ -67,7 +67,7 @@ public static Adam Adam(IEnumerable<(string name, Parameter parameter)> paramete
/// <param name="eps">Term added to the denominator to improve numerical stability (default: 1e-8)</param>
/// <param name="weight_decay">Weight decay (L2 penalty) (default: 0)</param>
/// <param name="amsgrad">Whether to use the AMSGrad variant of this algorithm. (default: False)</param>
/// <param name="maximize"></param>
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing</param>
/// <returns></returns>
public static Adam Adam(IEnumerable<Adam.ParamGroup> parameters, double lr = 1e-3, double beta1 = 0.9, double beta2 = 0.99, double eps = 1e-8, double weight_decay = 0, bool amsgrad = false, bool maximize = false)
{
17 changes: 17 additions & 0 deletions test/TorchSharpTest/TestTraining.cs
@@ -596,6 +596,23 @@ public void TestTrainingAdadelta()
LossIsClose(74.754f, loss);
}

[Fact]
public void TestTrainingAdadeltaMax()
{
var gen = new Generator(4711);
CreateLinearLayers(gen, out var lin1, out var lin2);
CreateDataAndLabels(gen, out var x, out var y);

var seq = Sequential(("lin1", lin1), ("relu1", ReLU()), ("lin2", lin2));

double learning_rate = 1.0f;
var optimizer = torch.optim.Adadelta(seq.parameters(), learning_rate, maximize:true);

var loss = TrainLoop(seq, x, y, optimizer, maximize: true);

LossIsClose(74.754f, -loss);
}

[Fact]
public void TestTrainingAdadeltaWD()
{
