From 87cdf8743af203a9af4a0fcbccb784900d6dc020 Mon Sep 17 00:00:00 2001
From: Evan Shelhamer
Date: Fri, 16 Oct 2015 21:11:32 -0700
Subject: [PATCH] drop Net inputs + Forward with bottoms

---
 .../cpp_classification/classification.cpp      |  25 ++---
 include/caffe/net.hpp                          |  34 ++----
 src/caffe/net.cpp                              | 105 ++----------------
 src/caffe/solver.cpp                           |   8 +-
 src/caffe/test/test_gradient_based_solver.cpp  |   6 +-
 src/caffe/test/test_net.cpp                    |  78 ++++++-------
 src/caffe/test/test_split_layer.cpp            |  61 ----------
 tools/caffe.cpp                                |   5 +-
 tools/extract_features.cpp                     |   3 +-
 9 files changed, 72 insertions(+), 253 deletions(-)

diff --git a/examples/cpp_classification/classification.cpp b/examples/cpp_classification/classification.cpp
index de48fb692c8..4bc23ca8ef8 100644
--- a/examples/cpp_classification/classification.cpp
+++ b/examples/cpp_classification/classification.cpp
@@ -59,14 +59,13 @@ Classifier::Classifier(const string& model_file,
   net_.reset(new Net<float>(model_file, TEST));
   net_->CopyTrainedLayersFrom(trained_file);
 
-  CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
   CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";
 
-  Blob<float>* input_layer = net_->input_blobs()[0];
-  num_channels_ = input_layer->channels();
+  shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");
+  num_channels_ = input_blob->channels();
   CHECK(num_channels_ == 3 || num_channels_ == 1)
     << "Input layer should have 1 or 3 channels.";
-  input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
+  input_geometry_ = cv::Size(input_blob->width(), input_blob->height());
 
   /* Load the binaryproto mean file. */
   SetMean(mean_file);
@@ -148,8 +147,8 @@ void Classifier::SetMean(const string& mean_file) {
 }
 
 std::vector<float> Classifier::Predict(const cv::Mat& img) {
-  Blob<float>* input_layer = net_->input_blobs()[0];
-  input_layer->Reshape(1, num_channels_,
+  shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");
+  input_blob->Reshape(1, num_channels_,
                        input_geometry_.height, input_geometry_.width);
   /* Forward dimension change to all layers. */
   net_->Reshape();
@@ -159,7 +158,7 @@ std::vector<float> Classifier::Predict(const cv::Mat& img) {
 
   Preprocess(img, &input_channels);
 
-  net_->ForwardPrefilled();
+  net_->Forward();
 
   /* Copy the output layer to a std::vector */
   Blob<float>* output_layer = net_->output_blobs()[0];
@@ -174,12 +173,12 @@ std::vector<float> Classifier::Predict(const cv::Mat& img) {
  * operation will write the separate channels directly to the input
 * layer. */
 void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
-  Blob<float>* input_layer = net_->input_blobs()[0];
+  shared_ptr<Blob<float> > input_blob = net_->blob_by_name("data");
 
-  int width = input_layer->width();
-  int height = input_layer->height();
-  float* input_data = input_layer->mutable_cpu_data();
-  for (int i = 0; i < input_layer->channels(); ++i) {
+  int width = input_blob->width();
+  int height = input_blob->height();
+  float* input_data = input_blob->mutable_cpu_data();
+  for (int i = 0; i < input_blob->channels(); ++i) {
     cv::Mat channel(height, width, CV_32FC1, input_data);
     input_channels->push_back(channel);
     input_data += width * height;
@@ -222,7 +221,7 @@ void Classifier::Preprocess(const cv::Mat& img,
   cv::split(sample_normalized, *input_channels);
 
   CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
-        == net_->input_blobs()[0]->cpu_data())
+        == net_->blob_by_name("data")->cpu_data())
     << "Input channels are not wrapping the input layer of the network.";
 }
 
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 1bf07d28d13..33ccec3476d 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -32,11 +32,10 @@ class Net {
   void Init(const NetParameter& param);
 
   /**
-   * @brief Run Forward with the input Blob%s already fed separately.
+   * @brief Run Forward and return the result.
    *
-   * You can get the input blobs using input_blobs().
    */
-  const vector<Blob<Dtype>*>& ForwardPrefilled(Dtype* loss = NULL);
+  const vector<Blob<Dtype>*>& Forward(Dtype* loss = NULL);
 
   /**
    * The From and To variants of Forward and Backward operate on the
@@ -49,14 +48,6 @@ class Net {
   Dtype ForwardFromTo(int start, int end);
   Dtype ForwardFrom(int start);
   Dtype ForwardTo(int end);
-  /// @brief Run forward using a set of bottom blobs, and return the result.
-  const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom,
-      Dtype* loss = NULL);
-  /**
-   * @brief Run forward using a serialized BlobProtoVector and return the
-   *        result as a serialized BlobProtoVector
-   */
-  string Forward(const string& input_blob_protos, Dtype* loss = NULL);
 
   /**
    * @brief Zeroes out the diffs of all net parameters.
@@ -82,9 +73,9 @@ class Net {
    */
   void Reshape();
 
-  Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
+  Dtype ForwardBackward() {
     Dtype loss;
-    Forward(bottom, &loss);
+    Forward(&loss);
     Backward();
     return loss;
   }
@@ -179,18 +170,11 @@ class Net {
     return param_names_index_;
   }
   inline const vector<int>& param_owners() const { return param_owners_; }
-  /// @brief Input and output blob numbers
-  inline int num_inputs() const { return net_input_blobs_.size(); }
+  /// @brief output blob number
   inline int num_outputs() const { return net_output_blobs_.size(); }
-  inline const vector<Blob<Dtype>*>& input_blobs() const {
-    return net_input_blobs_;
-  }
   inline const vector<Blob<Dtype>*>& output_blobs() const {
     return net_output_blobs_;
   }
-  inline const vector<int>& input_blob_indices() const {
-    return net_input_blob_indices_;
-  }
   inline const vector<int>& output_blob_indices() const {
     return net_output_blob_indices_;
   }
@@ -214,7 +198,7 @@ class Net {
 
  protected:
   // Helpers for Init.
-  /// @brief Append a new input or top blob to the net.
+  /// @brief Append a new top blob to the net.
   void AppendTop(const NetParameter& param, const int layer_id,
                  const int top_id, set<string>* available_blobs,
                  map<string, int>* blob_name_to_idx);
@@ -226,8 +210,6 @@ class Net {
   void AppendParam(const NetParameter& param, const int layer_id,
                    const int param_id);
 
-  /// @brief Helper for displaying debug info in Forward about input Blobs.
-  void InputDebugInfo(const int layer_id);
   /// @brief Helper for displaying debug info in Forward.
   void ForwardDebugInfo(const int layer_id);
   /// @brief Helper for displaying debug info in Backward.
@@ -266,10 +248,8 @@ class Net {
   vector<string> param_display_names_;
   vector<pair<int, int> > param_layer_indices_;
   map<string, int> param_names_index_;
-  /// blob indices for the input and the output of the net
-  vector<int> net_input_blob_indices_;
+  /// blob indices for the output of the net
   vector<int> net_output_blob_indices_;
-  vector<Blob<Dtype>*> net_input_blobs_;
   vector<Blob<Dtype>*> net_output_blobs_;
   /// The parameters in the network.
   vector<shared_ptr<Blob<Dtype> > > params_;
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 1ad93e6af5f..e9fcb24d0af 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -56,22 +56,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   name_ = param.name();
   map<string, int> blob_name_to_idx;
   set<string> available_blobs;
-  CHECK(param.input_dim_size() == 0 || param.input_shape_size() == 0)
-      << "Must specify either input_shape OR deprecated input_dim, not both.";
-  if (param.input_dim_size() > 0) {
-    // Deprecated 4D dimensions.
-    CHECK_EQ(param.input_size() * 4, param.input_dim_size())
-        << "Incorrect input blob dimension specifications.";
-  } else {
-    CHECK_EQ(param.input_size(), param.input_shape_size())
-        << "Exactly one input_shape must be specified per input.";
-  }
   memory_used_ = 0;
-  // set the input blobs
-  for (int input_id = 0; input_id < param.input_size(); ++input_id) {
-    const int layer_id = -1;  // inputs have fake layer ID -1
-    AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
-  }
   // For each layer, set up its input and output
   bottom_vecs_.resize(param.layer_size());
   top_vecs_.resize(param.layer_size());
@@ -379,19 +364,17 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
   return true;
 }
 
-// Helper for Net::Init: add a new input or top blob to the net. (Inputs have
-// layer_id == -1, tops have layer_id >= 0.)
+// Helper for Net::Init: add a new top blob to the net.
 template <typename Dtype>
 void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
                            const int top_id, set<string>* available_blobs,
                            map<string, int>* blob_name_to_idx) {
-  shared_ptr<LayerParameter> layer_param((layer_id >= 0) ?
-    (new LayerParameter(param.layer(layer_id))) : NULL);
-  const string& blob_name = layer_param ?
-      (layer_param->top_size() > top_id ?
-          layer_param->top(top_id) : "(automatic)") : param.input(top_id);
+  shared_ptr<LayerParameter> layer_param(
+      new LayerParameter(param.layer(layer_id)));
+  const string& blob_name = (layer_param->top_size() > top_id) ?
+      layer_param->top(top_id) : "(automatic)";
   // Check if we are doing in-place computation
-  if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
+  if (blob_name_to_idx && layer_param->bottom_size() > top_id &&
       blob_name == layer_param->bottom(top_id)) {
     // In-place computation
     LOG_IF(INFO, Caffe::root_solver())
@@ -407,11 +390,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
   } else {
     // Normal output.
     if (Caffe::root_solver()) {
-      if (layer_param) {
-        LOG(INFO) << layer_param->name() << " -> " << blob_name;
-      } else {
-        LOG(INFO) << "Input " << top_id << " -> " << blob_name;
-      }
+      LOG(INFO) << layer_param->name() << " -> " << blob_name;
     }
     shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
     const int blob_id = blobs_.size();
@@ -419,22 +398,8 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
     blob_names_.push_back(blob_name);
     blob_need_backward_.push_back(false);
     if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
-    if (layer_id == -1) {
-      // Set the (explicitly specified) dimensions of the input blob.
-      if (param.input_dim_size() > 0) {
-        blob_pointer->Reshape(param.input_dim(top_id * 4),
-                              param.input_dim(top_id * 4 + 1),
-                              param.input_dim(top_id * 4 + 2),
-                              param.input_dim(top_id * 4 + 3));
-      } else {
-        blob_pointer->Reshape(param.input_shape(top_id));
-      }
-      net_input_blob_indices_.push_back(blob_id);
-      net_input_blobs_.push_back(blob_pointer.get());
-    } else {
-      top_id_vecs_[layer_id].push_back(blob_id);
-      top_vecs_[layer_id].push_back(blob_pointer.get());
-    }
+    top_id_vecs_[layer_id].push_back(blob_id);
+    top_vecs_[layer_id].push_back(blob_pointer.get());
   }
   if (available_blobs) { available_blobs->insert(blob_name); }
 }
@@ -566,11 +531,6 @@ Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
   CHECK_GE(start, 0);
   CHECK_LT(end, layers_.size());
   Dtype loss = 0;
-  if (debug_info_) {
-    for (int i = 0; i < net_input_blobs_.size(); ++i) {
-      InputDebugInfo(i);
-    }
-  }
   for (int i = start; i <= end; ++i) {
     // LOG(ERROR) << "Forwarding " << layer_names_[i];
     Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
@@ -591,7 +551,7 @@ Dtype Net<Dtype>::ForwardTo(int end) {
 }
 
 template <typename Dtype>
-const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
+const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
   if (loss != NULL) {
     *loss = ForwardFromTo(0, layers_.size() - 1);
   } else {
@@ -600,37 +560,6 @@ const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
   return net_output_blobs_;
 }
 
-template <typename Dtype>
-const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
-    const vector<Blob<Dtype>*> & bottom, Dtype* loss) {
-  // Copy bottom to internal bottom
-  for (int i = 0; i < bottom.size(); ++i) {
-    net_input_blobs_[i]->CopyFrom(*bottom[i]);
-  }
-  return ForwardPrefilled(loss);
-}
-
-template <typename Dtype>
-string Net<Dtype>::Forward(const string& input_blob_protos, Dtype* loss) {
-  BlobProtoVector blob_proto_vec;
-  if (net_input_blobs_.size()) {
-    blob_proto_vec.ParseFromString(input_blob_protos);
-    CHECK_EQ(blob_proto_vec.blobs_size(), net_input_blobs_.size())
-        << "Incorrect input size.";
-    for (int i = 0; i < blob_proto_vec.blobs_size(); ++i) {
-      net_input_blobs_[i]->FromProto(blob_proto_vec.blobs(i));
-    }
-  }
-  ForwardPrefilled(loss);
-  blob_proto_vec.Clear();
-  for (int i = 0; i < net_output_blobs_.size(); ++i) {
-    net_output_blobs_[i]->ToProto(blob_proto_vec.add_blobs());
-  }
-  string output;
-  blob_proto_vec.SerializeToString(&output);
-  return output;
-}
-
 template <typename Dtype>
 void Net<Dtype>::BackwardFromTo(int start, int end) {
   CHECK_GE(end, 0);
@@ -644,17 +573,7 @@ void Net<Dtype>::BackwardFromTo(int start, int end) {
   }
 }
 
-template <typename Dtype>
-void Net<Dtype>::InputDebugInfo(const int input_id) {
-  const Blob<Dtype>& blob = *net_input_blobs_[input_id];
-  const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
-  const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-  LOG_IF(INFO, Caffe::root_solver())
-      << "    [Forward] "
-      << "Input " << blob_name << " data: " << data_abs_val_mean;
-}
-
 template <typename Dtype>
 void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
   for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
     const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id];
@@ -912,9 +830,6 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
   param->Clear();
   param->set_name(name_);
   // Add bottom and top
-  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
-    param->add_input(blob_names_[net_input_blob_indices_[i]]);
-  }
   DLOG(INFO) << "Serializing " << layers_.size() << " layers";
   for (int i = 0; i < layers_.size(); ++i) {
     LayerParameter* layer_param = param->add_layer();
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index d3bc7361dd5..0d898549adb 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -191,7 +191,6 @@ void Solver<Dtype>::InitTestNets() {
 
 template <typename Dtype>
 void Solver<Dtype>::Step(int iters) {
-  vector<Blob<Dtype>*> bottom_vec;
   const int start_iter = iter_;
   const int stop_iter = iter_ + iters;
   int average_loss = this->param_.average_loss();
@@ -219,7 +218,7 @@ void Solver<Dtype>::Step(int iters) {
     // accumulate the loss and gradient
     Dtype loss = 0;
     for (int i = 0; i < param_.iter_size(); ++i) {
-      loss += net_->ForwardBackward(bottom_vec);
+      loss += net_->ForwardBackward();
     }
     loss /= param_.iter_size();
     // average the loss across iterations for smoothed reporting
@@ -316,7 +315,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   // display the loss, which is computed in the forward pass.
   if (param_.display() && iter_ % param_.display() == 0) {
     Dtype loss;
-    net_->ForwardPrefilled(&loss);
+    net_->Forward(&loss);
     LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
   }
   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
@@ -343,7 +342,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
       ShareTrainedLayersWith(net_.get());
   vector<Dtype> test_score;
   vector<int> test_score_output_id;
-  vector<Blob<Dtype>*> bottom_vec;
   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
   Dtype loss = 0;
   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
@@ -364,7 +362,7 @@ void Solver<Dtype>::Test(const int test_net_id) {
 
     Dtype iter_loss;
     const vector<Blob<Dtype>*>& result =
-        test_net->Forward(bottom_vec, &iter_loss);
+        test_net->Forward(&iter_loss);
     if (param_.test_compute_loss()) {
       loss += iter_loss;
     }
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 84c6747f61a..09ec3a7e918 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -185,9 +185,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     this->InitSolverFromProtoString(proto.str());
     if (from_snapshot != NULL) {
       this->solver_->Restore(from_snapshot);
-      vector<Blob<Dtype>*> empty_bottom_vec;
       for (int i = 0; i < this->solver_->iter(); ++i) {
-        this->solver_->net()->Forward(empty_bottom_vec);
+        this->solver_->net()->Forward();
      }
     }
     if (devices == 1) {
@@ -231,8 +230,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
     // Run a forward pass, and manually compute the update values from the
     // result.
Net& net = *this->solver_->net(); - vector*> empty_bottom_vec; - net.Forward(empty_bottom_vec); + net.Forward(); ASSERT_TRUE(net.has_blob("data")); const Blob& data = *net.blob_by_name("data"); ASSERT_TRUE(net.has_blob("targets")); diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp index ab4afba1a93..0751e3d57a0 100644 --- a/src/caffe/test/test_net.cpp +++ b/src/caffe/test/test_net.cpp @@ -555,11 +555,14 @@ class NetTest : public MultiDeviceTest { virtual void InitReshapableNet() { const string& proto = "name: 'ReshapableNetwork' " - "input: 'data' " - "input_dim: 1 " - "input_dim: 3 " - "input_dim: 100 " - "input_dim: 100 " + "layer { " + " name: 'data' " + " type: 'InputData' " + " top: 'data' " + " input_data_param { " + " shape: { dim: 1 dim: 3 dim: 100 dim: 100 } " + " } " + "} " "layer { " " name: 'conv1' " " type: 'Convolution' " @@ -821,7 +824,7 @@ TYPED_TEST(NetTest, TestLossWeight) { Caffe::set_random_seed(this->seed_); const bool kForceBackward = true; this->InitUnsharedWeightsNet(NULL, NULL, kForceBackward); - const Dtype loss = this->net_->ForwardBackward(bottom); + const Dtype loss = this->net_->ForwardBackward(); const bool kCopyDiff = true; vector > > blob_grads; this->CopyNetBlobs(kCopyDiff, &blob_grads); @@ -836,7 +839,7 @@ TYPED_TEST(NetTest, TestLossWeight) { for (int i = 0; i < kNumLossWeights; ++i) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&kLossWeights[i], NULL, kForceBackward); - const Dtype weighted_loss = this->net_->ForwardBackward(bottom); + const Dtype weighted_loss = this->net_->ForwardBackward(); const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]); EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin) << "loss weight = " << kLossWeights[i]; @@ -865,14 +868,13 @@ TYPED_TEST(NetTest, TestLossWeight) { TYPED_TEST(NetTest, TestLossWeightMidNet) { typedef typename TypeParam::Dtype Dtype; - vector*> bottom; Caffe::set_random_seed(this->seed_); const bool kForceBackward = true; Dtype loss_weight = 0; Dtype midnet_loss_weight = 1; this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss = this->net_->ForwardBackward(bottom); + const Dtype loss = this->net_->ForwardBackward(); const bool kCopyDiff = true; const bool kReshape = true; Blob data_grad; @@ -887,7 +889,7 @@ TYPED_TEST(NetTest, TestLossWeightMidNet) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&loss_weight, &kLossWeights[i], kForceBackward); - const Dtype weighted_loss = this->net_->ForwardBackward(bottom); + const Dtype weighted_loss = this->net_->ForwardBackward(); const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]); EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin) << "loss weight = " << kLossWeights[i]; @@ -903,7 +905,6 @@ TYPED_TEST(NetTest, TestLossWeightMidNet) { TYPED_TEST(NetTest, TestComboLossWeight) { typedef typename TypeParam::Dtype Dtype; - vector*> bottom; Dtype loss_weight; Dtype midnet_loss_weight; const bool kForceBackward = true; @@ -916,7 +917,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss = this->net_->ForwardBackward(bottom); + const Dtype loss = this->net_->ForwardBackward(); const bool kCopyDiff = true; vector > > blob_grads; this->CopyNetBlobs(kCopyDiff, &blob_grads); @@ -928,7 +929,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Caffe::set_random_seed(this->seed_); 
this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss_main_2 = this->net_->ForwardBackward(bottom); + const Dtype loss_main_2 = this->net_->ForwardBackward(); vector > > blob_grads_loss_2; this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2); vector > > param_grads_loss_2; @@ -939,7 +940,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss_main_3 = this->net_->ForwardBackward(bottom); + const Dtype loss_main_3 = this->net_->ForwardBackward(); const vector > >& blob_grads_loss_3 = this->net_->blobs(); ASSERT_EQ(blob_grads.size(), blob_grads_loss_3.size()); @@ -974,7 +975,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss_midnet_2 = this->net_->ForwardBackward(bottom); + const Dtype loss_midnet_2 = this->net_->ForwardBackward(); this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2); this->CopyNetParams(kCopyDiff, ¶m_grads_loss_2); @@ -983,7 +984,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) { Caffe::set_random_seed(this->seed_); this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight, kForceBackward); - const Dtype loss_midnet_3 = this->net_->ForwardBackward(bottom); + const Dtype loss_midnet_3 = this->net_->ForwardBackward(); const vector > >& blob_grads_midnet_loss_3 = this->net_->blobs(); ASSERT_EQ(blob_grads.size(), blob_grads_midnet_loss_3.size()); @@ -1032,40 +1033,35 @@ TYPED_TEST(NetTest, TestComboLossWeight) { } TYPED_TEST(NetTest, TestBackwardWithAccuracyLayer) { - typedef typename TypeParam::Dtype Dtype; const bool kForceBackward = false; const bool kAccuracyLayer = true; this->InitTinyNet(kForceBackward, kAccuracyLayer); EXPECT_TRUE(this->net_->has_blob("accuracy")); - vector*> bottom; // Test that we can do Backward even though we have an 'Accuracy' layer. 
- this->net_->ForwardBackward(bottom); + this->net_->ForwardBackward(); } TYPED_TEST(NetTest, TestUnsharedWeightsDataNet) { typedef typename TypeParam::Dtype Dtype; this->InitUnsharedWeightsNet(); - vector*> bottom; Dtype loss; - this->net_->Forward(bottom, &loss); + this->net_->Forward(&loss); EXPECT_GT(loss, 0); } TYPED_TEST(NetTest, TestSharedWeightsDataNet) { typedef typename TypeParam::Dtype Dtype; this->InitSharedWeightsNet(); - vector*> bottom; Dtype loss; - this->net_->Forward(bottom, &loss); + this->net_->Forward(&loss); EXPECT_FLOAT_EQ(loss, 0); } TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) { typedef typename TypeParam::Dtype Dtype; this->InitUnsharedWeightsNet(); - vector*> bottom; Net* net = this->net_.get(); - net->Forward(bottom); + net->Forward(); net->Backward(); Layer* ip1_layer = net->layer_by_name("innerproduct1").get(); Layer* ip2_layer = net->layer_by_name("innerproduct2").get(); @@ -1081,10 +1077,9 @@ TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) { TYPED_TEST(NetTest, TestSharedWeightsDiffNet) { typedef typename TypeParam::Dtype Dtype; this->InitSharedWeightsNet(); - vector*> bottom; Net* net = this->net_.get(); Dtype loss; - net->Forward(bottom, &loss); + net->Forward(&loss); net->Backward(); EXPECT_FLOAT_EQ(loss, 0); Layer* ip1_layer = net->layer_by_name("innerproduct1").get(); @@ -1102,7 +1097,6 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) { typedef typename TypeParam::Dtype Dtype; Caffe::set_random_seed(this->seed_); this->InitDiffDataSharedWeightsNet(); - vector*> bottom; EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1"); EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2"); Blob* ip1_weights = this->net_->layers()[1]->blobs()[0].get(); @@ -1111,7 +1105,7 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) { // locations. EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); EXPECT_EQ(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); // Compute the expected update as the data minus the two diffs. Blob shared_params; @@ -1146,7 +1140,7 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) { // locations in memory. EXPECT_NE(ip1_weights->cpu_data(), ip2_weights->cpu_data()); EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); // Compute the expected update. Blob unshared_params1; @@ -1186,7 +1180,6 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) { // Create a net with weight sharing; Update it once. Caffe::set_random_seed(this->seed_); this->InitDiffDataSharedWeightsNet(); - vector*> bottom; EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1"); EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2"); Blob* ip1_weights = this->net_->layers()[1]->blobs()[0].get(); @@ -1195,7 +1188,7 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) { // locations. 
EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data()); EXPECT_EQ(ip1_weights->cpu_diff(), ip2_weights->cpu_diff()); - this->net_->ForwardBackward(bottom); + this->net_->ForwardBackward(); this->net_->Update(); Blob shared_params; const bool kReshape = true; @@ -1228,7 +1221,6 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) { TYPED_TEST(NetTest, TestParamPropagateDown) { typedef typename TypeParam::Dtype Dtype; - vector*> bottom; const bool kBiasTerm = true, kForceBackward = false; const Dtype* kLossWeight1 = NULL; const Dtype* kLossWeight2 = NULL; @@ -1238,7 +1230,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) { Dtype blobs_lr_w1 = 1, blobs_lr_w2 = 1, blobs_lr_b1 = 2, blobs_lr_b2 = 2; this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); const vector > >& params = this->net_->params(); const int num_params = params.size(); @@ -1258,7 +1250,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) { blobs_lr_w1 *= 2, blobs_lr_w2 *= 2, blobs_lr_b1 *= 2, blobs_lr_b2 *= 2; this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); const vector > >& params2 = this->net_->params(); ASSERT_EQ(num_params, params2.size()); @@ -1274,7 +1266,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) { blobs_lr_w1 = 1, blobs_lr_w2 = 0, blobs_lr_b1 = 0, blobs_lr_b2 = 1; this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); const vector > >& params3 = this->net_->params(); ASSERT_EQ(num_params, params3.size()); @@ -1293,7 +1285,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) { blobs_lr_w1 = 0, blobs_lr_w2 = 1, blobs_lr_b1 = 1, blobs_lr_b2 = 0; this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward, kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2); - this->net_->Forward(bottom); + this->net_->Forward(); this->net_->Backward(); const vector > >& params4 = this->net_->params(); ASSERT_EQ(num_params, params4.size()); @@ -1315,7 +1307,7 @@ TYPED_TEST(NetTest, TestFromTo) { // Run Forward and Backward, recording the data diff and loss. 
Blob data; data.ReshapeLike(*this->net_->blob_by_name("data")); - this->net_->ForwardPrefilled(); + this->net_->Forward(); this->net_->Backward(); data.CopyFrom(*this->net_->blob_by_name("data"), true, true); const Dtype *loss_ptr = this->net_->output_blobs()[0]->cpu_data(); @@ -2277,12 +2269,12 @@ TYPED_TEST(NetTest, TestReshape) { filler.Fill(&blob2); this->InitReshapableNet(); - Blob* input_blob = this->net_->input_blobs()[0]; + shared_ptr > input_blob = this->net_->blob_by_name("data"); Blob* output_blob = this->net_->output_blobs()[0]; input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), blob1.width()); caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); - this->net_->ForwardPrefilled(); + this->net_->Forward(); // call backward just to make sure it runs this->net_->Backward(); Blob output1(output_blob->num(), output_blob->channels(), @@ -2293,7 +2285,7 @@ TYPED_TEST(NetTest, TestReshape) { input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), blob2.width()); caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); - this->net_->ForwardPrefilled(); + this->net_->Forward(); this->net_->Backward(); Blob output2(output_blob->num(), output_blob->channels(), output_blob->height(), output_blob->width()); @@ -2303,7 +2295,7 @@ TYPED_TEST(NetTest, TestReshape) { input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(), blob1.width()); caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data()); - this->net_->ForwardPrefilled(); + this->net_->Forward(); this->net_->Backward(); for (int i = 0; i < output1.count(); ++i) { EXPECT_FLOAT_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i)); @@ -2312,7 +2304,7 @@ TYPED_TEST(NetTest, TestReshape) { input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(), blob2.width()); caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data()); - this->net_->ForwardPrefilled(); + this->net_->Forward(); this->net_->Backward(); for (int i = 0; i < output2.count(); ++i) { EXPECT_FLOAT_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i)); diff --git a/src/caffe/test/test_split_layer.cpp b/src/caffe/test/test_split_layer.cpp index be5204bfc3e..cffdf824bc6 100644 --- a/src/caffe/test/test_split_layer.cpp +++ b/src/caffe/test/test_split_layer.cpp @@ -887,67 +887,6 @@ TEST_F(SplitLayerInsertionTest, TestInsertionTwoTop) { this->RunInsertionTest(input_proto, expected_output_proto); } -TEST_F(SplitLayerInsertionTest, TestInputInsertion) { - const string& input_proto = - "name: 'TestNetwork' " - "input: 'data' " - "input_dim: 10 " - "input_dim: 3 " - "input_dim: 227 " - "input_dim: 227 " - "layer { " - " name: 'innerprod1' " - " type: 'InnerProduct' " - " bottom: 'data' " - " top: 'innerprod1' " - "} " - "layer { " - " name: 'innerprod2' " - " type: 'InnerProduct' " - " bottom: 'data' " - " top: 'innerprod2' " - "} " - "layer { " - " name: 'loss' " - " type: 'EuclideanLoss' " - " bottom: 'innerprod1' " - " bottom: 'innerprod2' " - "} "; - const string& expected_output_proto = - "name: 'TestNetwork' " - "input: 'data' " - "input_dim: 10 " - "input_dim: 3 " - "input_dim: 227 " - "input_dim: 227 " - "layer { " - " name: 'data_input_0_split' " - " type: 'Split' " - " bottom: 'data' " - " top: 'data_input_0_split_0' " - " top: 'data_input_0_split_1' " - "} " - "layer { " - " name: 'innerprod1' " - " type: 'InnerProduct' " - " bottom: 'data_input_0_split_0' " - " top: 'innerprod1' " - "} " - "layer { " - " name: 'innerprod2' " - " type: 
'InnerProduct' " - " bottom: 'data_input_0_split_1' " - " top: 'innerprod2' " - "} " - "layer { " - " name: 'loss' " - " type: 'EuclideanLoss' " - " bottom: 'innerprod1' " - " bottom: 'innerprod2' " - "} "; - this->RunInsertionTest(input_proto, expected_output_proto); -} - TEST_F(SplitLayerInsertionTest, TestWithInPlace) { const string& input_proto = "name: 'TestNetwork' " diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 305cfc3635d..d0166efc7d1 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -239,14 +239,13 @@ int test() { caffe_net.CopyTrainedLayersFrom(FLAGS_weights); LOG(INFO) << "Running for " << FLAGS_iterations << " iterations."; - vector* > bottom_vec; vector test_score_output_id; vector test_score; float loss = 0; for (int i = 0; i < FLAGS_iterations; ++i) { float iter_loss; const vector*>& result = - caffe_net.Forward(bottom_vec, &iter_loss); + caffe_net.Forward(&iter_loss); loss += iter_loss; int idx = 0; for (int j = 0; j < result.size(); ++j) { @@ -310,7 +309,7 @@ int time() { // Note that for the speed benchmark, we will assume that the network does // not take any input blobs. float initial_loss; - caffe_net.Forward(vector*>(), &initial_loss); + caffe_net.Forward(&initial_loss); LOG(INFO) << "Initial loss: " << initial_loss; LOG(INFO) << "Performing Backward"; caffe_net.Backward(); diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 084c9bf88df..b12f41f350a 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -137,10 +137,9 @@ int feature_extraction_pipeline(int argc, char** argv) { Datum datum; const int kMaxKeyStrLength = 100; char key_str[kMaxKeyStrLength]; - std::vector*> input_vec; std::vector image_indices(num_features, 0); for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) { - feature_extraction_net->Forward(input_vec); + feature_extraction_net->Forward(); for (int i = 0; i < num_features; ++i) { const shared_ptr > feature_blob = feature_extraction_net ->blob_by_name(blob_names[i]);
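
After this patch the only way to feed a Net is through its named blobs: the `Forward(bottom, ...)`, `Forward(input_blob_protos, ...)`, and `ForwardPrefilled()` entry points are gone, and callers look up the input blob, fill it in place, and call `Forward()`, exactly as the classification example above now does. Below is a minimal sketch of that calling convention; the prototxt/caffemodel paths and the "data" blob name are assumptions for illustration, not part of the patch.

```cpp
#include <algorithm>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/net.hpp"

// Sketch only: substitute the model files and input blob name your net uses.
std::vector<float> RunOnce(const std::vector<float>& input) {
  caffe::Net<float> net("deploy.prototxt", caffe::TEST);  // hypothetical files
  net.CopyTrainedLayersFrom("model.caffemodel");

  // Fill the named input blob in place instead of passing bottom blobs.
  const boost::shared_ptr<caffe::Blob<float> > data = net.blob_by_name("data");
  CHECK_EQ(data->count(), static_cast<int>(input.size()))
      << "Input size must match the input blob.";
  std::copy(input.begin(), input.end(), data->mutable_cpu_data());

  // Forward now takes no bottoms and returns the net's output blobs.
  float loss;
  const std::vector<caffe::Blob<float>*>& outputs = net.Forward(&loss);

  // Copy the first output blob out as a flat vector.
  const float* begin = outputs[0]->cpu_data();
  return std::vector<float>(begin, begin + outputs[0]->count());
}
```

To change the input shape between calls, reshape the same named blob and call `net.Reshape()` before `Forward()`, as `Classifier::Predict()` does above.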