
Commit

drop Net inputs + Forward with bottoms
shelhamer committed Oct 17, 2015
1 parent 9698cc8 commit 87cdf87
Showing 9 changed files with 72 additions and 253 deletions.
25 changes: 12 additions & 13 deletions examples/cpp_classification/classification.cpp
@@ -59,14 +59,13 @@ Classifier::Classifier(const string& model_file,
net_.reset(new Net<float>(model_file, TEST));
net_->CopyTrainedLayersFrom(trained_file);

CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";

Blob<float>* input_layer = net_->input_blobs()[0];
num_channels_ = input_layer->channels();
shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");
num_channels_ = input_blob->channels();
CHECK(num_channels_ == 3 || num_channels_ == 1)
<< "Input layer should have 1 or 3 channels.";
input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
input_geometry_ = cv::Size(input_blob->width(), input_blob->height());

/* Load the binaryproto mean file. */
SetMean(mean_file);
@@ -148,8 +147,8 @@ void Classifier::SetMean(const string& mean_file) {
}

std::vector<float> Classifier::Predict(const cv::Mat& img) {
Blob<float>* input_layer = net_->input_blobs()[0];
input_layer->Reshape(1, num_channels_,
shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");
input_blob->Reshape(1, num_channels_,
input_geometry_.height, input_geometry_.width);
/* Forward dimension change to all layers. */
net_->Reshape();
@@ -159,7 +158,7 @@ std::vector<float> Classifier::Predict(const cv::Mat& img) {

Preprocess(img, &input_channels);

net_->ForwardPrefilled();
net_->Forward();

/* Copy the output layer to a std::vector */
Blob<float>* output_layer = net_->output_blobs()[0];
@@ -174,12 +173,12 @@ std::vector<float> Classifier::Predict(const cv::Mat& img) {
* operation will write the separate channels directly to the input
* layer. */
void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
Blob<float>* input_layer = net_->input_blobs()[0];
shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");

int width = input_layer->width();
int height = input_layer->height();
float* input_data = input_layer->mutable_cpu_data();
for (int i = 0; i < input_layer->channels(); ++i) {
int width = input_blob->width();
int height = input_blob->height();
float* input_data = input_blob->mutable_cpu_data();
for (int i = 0; i < input_blob->channels(); ++i) {
cv::Mat channel(height, width, CV_32FC1, input_data);
input_channels->push_back(channel);
input_data += width * height;
@@ -222,7 +221,7 @@ void Classifier::Preprocess(const cv::Mat& img,
cv::split(sample_normalized, *input_channels);

CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
== net_->input_blobs()[0]->cpu_data())
== net_->blob_by_name("data")->cpu_data())
<< "Input channels are not wrapping the input layer of the network.";
}
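
For orientation, a minimal sketch of the inference flow the updated classification example now follows. The file names, and the assumption that the deploy net exposes an input blob named "data" for the caller to fill, are placeholders taken from this example rather than requirements of the API.

#include <algorithm>
#include <vector>

#include <boost/shared_ptr.hpp>

#include "caffe/blob.hpp"
#include "caffe/net.hpp"

// Illustrative helper; "deploy.prototxt" and "weights.caffemodel" are
// placeholder file names, and pixels.size() must equal channels*height*width.
std::vector<float> Infer(const std::vector<float>& pixels,
                         int channels, int height, int width) {
  caffe::Net<float> net("deploy.prototxt", caffe::TEST);
  net.CopyTrainedLayersFrom("weights.caffemodel");

  // The input is addressed by name instead of through input_blobs().
  boost::shared_ptr<caffe::Blob<float> > input = net.blob_by_name("data");
  input->Reshape(1, channels, height, width);
  net.Reshape();  // forward the dimension change to all layers

  std::copy(pixels.begin(), pixels.end(), input->mutable_cpu_data());

  net.Forward();  // replaces ForwardPrefilled()

  const caffe::Blob<float>* output = net.output_blobs()[0];
  return std::vector<float>(output->cpu_data(),
                            output->cpu_data() + output->count());
}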

34 changes: 7 additions & 27 deletions include/caffe/net.hpp
@@ -32,11 +32,10 @@ class Net {
void Init(const NetParameter& param);

/**
* @brief Run Forward with the input Blob%s already fed separately.
* @brief Run Forward and return the result.
*
* You can get the input blobs using input_blobs().
*/
const vector<Blob<Dtype>*>& ForwardPrefilled(Dtype* loss = NULL);
const vector<Blob<Dtype>*>& Forward(Dtype* loss = NULL);

/**
* The From and To variants of Forward and Backward operate on the
@@ -49,14 +48,6 @@ class Net {
Dtype ForwardFromTo(int start, int end);
Dtype ForwardFrom(int start);
Dtype ForwardTo(int end);
/// @brief Run forward using a set of bottom blobs, and return the result.
const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom,
Dtype* loss = NULL);
/**
* @brief Run forward using a serialized BlobProtoVector and return the
* result as a serialized BlobProtoVector
*/
string Forward(const string& input_blob_protos, Dtype* loss = NULL);

/**
* @brief Zeroes out the diffs of all net parameters.
@@ -82,9 +73,9 @@ class Net {
*/
void Reshape();

Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
Dtype ForwardBackward() {
Dtype loss;
Forward(bottom, &loss);
Forward(&loss);
Backward();
return loss;
}
@@ -179,18 +170,11 @@ class Net {
return param_names_index_;
}
inline const vector<int>& param_owners() const { return param_owners_; }
/// @brief Input and output blob numbers
inline int num_inputs() const { return net_input_blobs_.size(); }
/// @brief output blob number
inline int num_outputs() const { return net_output_blobs_.size(); }
inline const vector<Blob<Dtype>*>& input_blobs() const {
return net_input_blobs_;
}
inline const vector<Blob<Dtype>*>& output_blobs() const {
return net_output_blobs_;
}
inline const vector<int>& input_blob_indices() const {
return net_input_blob_indices_;
}
inline const vector<int>& output_blob_indices() const {
return net_output_blob_indices_;
}
@@ -214,7 +198,7 @@ class Net {

protected:
// Helpers for Init.
/// @brief Append a new input or top blob to the net.
/// @brief Append a new top blob to the net.
void AppendTop(const NetParameter& param, const int layer_id,
const int top_id, set<string>* available_blobs,
map<string, int>* blob_name_to_idx);
@@ -226,8 +210,6 @@ class Net {
void AppendParam(const NetParameter& param, const int layer_id,
const int param_id);

/// @brief Helper for displaying debug info in Forward about input Blobs.
void InputDebugInfo(const int layer_id);
/// @brief Helper for displaying debug info in Forward.
void ForwardDebugInfo(const int layer_id);
/// @brief Helper for displaying debug info in Backward.
@@ -266,10 +248,8 @@ class Net {
vector<string> param_display_names_;
vector<pair<int, int> > param_layer_indices_;
map<string, int> param_names_index_;
/// blob indices for the input and the output of the net
vector<int> net_input_blob_indices_;
/// blob indices for the output of the net
vector<int> net_output_blob_indices_;
vector<Blob<Dtype>*> net_input_blobs_;
vector<Blob<Dtype>*> net_output_blobs_;
/// The parameters in the network.
vector<shared_ptr<Blob<Dtype> > > params_;
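
The header now exposes a single Forward(Dtype* loss = NULL) in place of ForwardPrefilled() and the bottom-vector and serialized-proto overloads. A hedged sketch of a call site after the change, assuming the net's input blobs are already populated (for example by a data layer); the helper name is illustrative, not part of this commit:

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/net.hpp"

float RunOnce(caffe::Net<float>& net) {
  // Before: net.ForwardPrefilled(&loss)  or  net.Forward(bottom_vec, &loss)
  // After:  one entry point with an optional loss output.
  float loss = 0;
  const std::vector<caffe::Blob<float>*>& outputs = net.Forward(&loss);
  (void)outputs;  // the net's output blobs, e.g. scores or a loss blob
  return loss;
}

ForwardBackward() follows the same pattern: it takes no bottom vector and simply runs Forward(&loss) followed by Backward().
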
105 changes: 10 additions & 95 deletions src/caffe/net.cpp
@@ -56,22 +56,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
name_ = param.name();
map<string, int> blob_name_to_idx;
set<string> available_blobs;
CHECK(param.input_dim_size() == 0 || param.input_shape_size() == 0)
<< "Must specify either input_shape OR deprecated input_dim, not both.";
if (param.input_dim_size() > 0) {
// Deprecated 4D dimensions.
CHECK_EQ(param.input_size() * 4, param.input_dim_size())
<< "Incorrect input blob dimension specifications.";
} else {
CHECK_EQ(param.input_size(), param.input_shape_size())
<< "Exactly one input_shape must be specified per input.";
}
memory_used_ = 0;
// set the input blobs
for (int input_id = 0; input_id < param.input_size(); ++input_id) {
const int layer_id = -1; // inputs have fake layer ID -1
AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
}
// For each layer, set up its input and output
bottom_vecs_.resize(param.layer_size());
top_vecs_.resize(param.layer_size());
@@ -379,19 +364,17 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
return true;
}

// Helper for Net::Init: add a new input or top blob to the net. (Inputs have
// layer_id == -1, tops have layer_id >= 0.)
// Helper for Net::Init: add a new top blob to the net.
template <typename Dtype>
void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
const int top_id, set<string>* available_blobs,
map<string, int>* blob_name_to_idx) {
shared_ptr<LayerParameter> layer_param((layer_id >= 0) ?
(new LayerParameter(param.layer(layer_id))) : NULL);
const string& blob_name = layer_param ?
(layer_param->top_size() > top_id ?
layer_param->top(top_id) : "(automatic)") : param.input(top_id);
shared_ptr<LayerParameter> layer_param(
new LayerParameter(param.layer(layer_id)));
const string& blob_name = (layer_param->top_size() > top_id) ?
layer_param->top(top_id) : "(automatic)";
// Check if we are doing in-place computation
if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
if (blob_name_to_idx && layer_param->bottom_size() > top_id &&
blob_name == layer_param->bottom(top_id)) {
// In-place computation
LOG_IF(INFO, Caffe::root_solver())
@@ -407,34 +390,16 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
} else {
// Normal output.
if (Caffe::root_solver()) {
if (layer_param) {
LOG(INFO) << layer_param->name() << " -> " << blob_name;
} else {
LOG(INFO) << "Input " << top_id << " -> " << blob_name;
}
LOG(INFO) << layer_param->name() << " -> " << blob_name;
}
shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
const int blob_id = blobs_.size();
blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
blob_need_backward_.push_back(false);
if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
if (layer_id == -1) {
// Set the (explicitly specified) dimensions of the input blob.
if (param.input_dim_size() > 0) {
blob_pointer->Reshape(param.input_dim(top_id * 4),
param.input_dim(top_id * 4 + 1),
param.input_dim(top_id * 4 + 2),
param.input_dim(top_id * 4 + 3));
} else {
blob_pointer->Reshape(param.input_shape(top_id));
}
net_input_blob_indices_.push_back(blob_id);
net_input_blobs_.push_back(blob_pointer.get());
} else {
top_id_vecs_[layer_id].push_back(blob_id);
top_vecs_[layer_id].push_back(blob_pointer.get());
}
top_id_vecs_[layer_id].push_back(blob_id);
top_vecs_[layer_id].push_back(blob_pointer.get());
}
if (available_blobs) { available_blobs->insert(blob_name); }
}
@@ -566,11 +531,6 @@ Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
CHECK_GE(start, 0);
CHECK_LT(end, layers_.size());
Dtype loss = 0;
if (debug_info_) {
for (int i = 0; i < net_input_blobs_.size(); ++i) {
InputDebugInfo(i);
}
}
for (int i = start; i <= end; ++i) {
// LOG(ERROR) << "Forwarding " << layer_names_[i];
Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
@@ -591,7 +551,7 @@ Dtype Net<Dtype>::ForwardTo(int end) {
}

template <typename Dtype>
const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
if (loss != NULL) {
*loss = ForwardFromTo(0, layers_.size() - 1);
} else {
Expand All @@ -600,37 +560,6 @@ const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
return net_output_blobs_;
}

template <typename Dtype>
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
const vector<Blob<Dtype>*> & bottom, Dtype* loss) {
// Copy bottom to internal bottom
for (int i = 0; i < bottom.size(); ++i) {
net_input_blobs_[i]->CopyFrom(*bottom[i]);
}
return ForwardPrefilled(loss);
}

template <typename Dtype>
string Net<Dtype>::Forward(const string& input_blob_protos, Dtype* loss) {
BlobProtoVector blob_proto_vec;
if (net_input_blobs_.size()) {
blob_proto_vec.ParseFromString(input_blob_protos);
CHECK_EQ(blob_proto_vec.blobs_size(), net_input_blobs_.size())
<< "Incorrect input size.";
for (int i = 0; i < blob_proto_vec.blobs_size(); ++i) {
net_input_blobs_[i]->FromProto(blob_proto_vec.blobs(i));
}
}
ForwardPrefilled(loss);
blob_proto_vec.Clear();
for (int i = 0; i < net_output_blobs_.size(); ++i) {
net_output_blobs_[i]->ToProto(blob_proto_vec.add_blobs());
}
string output;
blob_proto_vec.SerializeToString(&output);
return output;
}

template <typename Dtype>
void Net<Dtype>::BackwardFromTo(int start, int end) {
CHECK_GE(end, 0);
@@ -644,17 +573,6 @@ void Net<Dtype>::BackwardFromTo(int start, int end) {
}
}

template <typename Dtype>
void Net<Dtype>::InputDebugInfo(const int input_id) {
const Blob<Dtype>& blob = *net_input_blobs_[input_id];
const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
LOG_IF(INFO, Caffe::root_solver())
<< " [Forward] "
<< "Input " << blob_name << " data: " << data_abs_val_mean;
}

template <typename Dtype>
void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id];
@@ -912,9 +830,6 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
param->Clear();
param->set_name(name_);
// Add bottom and top
for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
param->add_input(blob_names_[net_input_blob_indices_[i]]);
}
DLOG(INFO) << "Serializing " << layers_.size() << " layers";
for (int i = 0; i < layers_.size(); ++i) {
LayerParameter* layer_param = param->add_layer();
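
With the input:/input_dim:/input_shape: handling removed from Net::Init, code that previously relied on Forward(bottom_vec) can copy its data into a named blob and then run the plain forward pass, mirroring the updated classification example above. A rough sketch, assuming the net definition declares a blob named "data" that the caller is meant to fill:

#include <vector>

#include <boost/shared_ptr.hpp>

#include "caffe/blob.hpp"
#include "caffe/net.hpp"

// Illustrative helper doing roughly what the removed Forward(bottom, loss)
// overload did: copy the caller's data in, then forward the whole net.
const std::vector<caffe::Blob<float>*>& ForwardWithData(
    caffe::Net<float>& net, const caffe::Blob<float>& bottom, float* loss) {
  boost::shared_ptr<caffe::Blob<float> > data = net.blob_by_name("data");
  data->CopyFrom(bottom, /*copy_diff=*/false, /*reshape=*/true);
  net.Reshape();             // propagate any shape change through the layers
  return net.Forward(loss);  // the single remaining Forward entry point
}
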
8 changes: 3 additions & 5 deletions src/caffe/solver.cpp
@@ -191,7 +191,6 @@ void Solver<Dtype>::InitTestNets() {

template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
vector<Blob<Dtype>*> bottom_vec;
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();
@@ -219,7 +218,7 @@ void Solver<Dtype>::Step(int iters) {
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward(bottom_vec);
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();
// average the loss across iterations for smoothed reporting
@@ -316,7 +315,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == 0) {
Dtype loss;
net_->ForwardPrefilled(&loss);
net_->Forward(&loss);
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
@@ -343,7 +342,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
ShareTrainedLayersWith(net_.get());
vector<Dtype> test_score;
vector<int> test_score_output_id;
vector<Blob<Dtype>*> bottom_vec;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
@@ -364,7 +362,7 @@ void Solver<Dtype>::Test(const int test_net_id) {

Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(bottom_vec, &iter_loss);
test_net->Forward(&iter_loss);
if (param_.test_compute_loss()) {
loss += iter_loss;
}
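
Condensed from the loop in Solver<Dtype>::Step() above, a sketch of the loss accumulation after this change; gradient clearing, callbacks, and the parameter update are omitted, and the helper name is illustrative:

#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"

float AccumulateLoss(caffe::Net<float>& net,
                     const caffe::SolverParameter& param) {
  float loss = 0;
  for (int i = 0; i < param.iter_size(); ++i) {
    // ForwardBackward() no longer takes a bottom vector; the net's own
    // layers (or pre-filled blobs) supply the inputs.
    loss += net.ForwardBackward();
  }
  return loss / param.iter_size();
}
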
6 changes: 2 additions & 4 deletions src/caffe/test/test_gradient_based_solver.cpp
@@ -185,9 +185,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
this->InitSolverFromProtoString(proto.str());
if (from_snapshot != NULL) {
this->solver_->Restore(from_snapshot);
vector<Blob<Dtype>*> empty_bottom_vec;
for (int i = 0; i < this->solver_->iter(); ++i) {
this->solver_->net()->Forward(empty_bottom_vec);
this->solver_->net()->Forward();
}
}
if (devices == 1) {
@@ -231,8 +230,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
// Run a forward pass, and manually compute the update values from the
// result.
Net<Dtype>& net = *this->solver_->net();
vector<Blob<Dtype>*> empty_bottom_vec;
net.Forward(empty_bottom_vec);
net.Forward();
ASSERT_TRUE(net.has_blob("data"));
const Blob<Dtype>& data = *net.blob_by_name("data");
ASSERT_TRUE(net.has_blob("targets"));

0 comments on commit 87cdf87
