Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stan/math/rev/fun/log_sum_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class log_sum_exp_vv_vari : public op_vv_vari {
log_sum_exp_vv_vari(vari* avi, vari* bvi)
: op_vv_vari(log_sum_exp(avi->val_, bvi->val_), avi, bvi) {}
void chain() {
avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to get rid of calculate_chain everywhere? It seems like a strange function.

Also used here (which should be fixed as part of this pull):

avi_->adj_ += adj_ * calculate_chain(avi_->val_, val_);

In stan/math/rev/fun/log_diff_exp.hpp and in stan/math/rev/fun/log1p_exp.hpp.

bvi_->adj_ += adj_ * calculate_chain(bvi_->val_, val_);
avi_->adj_ += adj_ / (1 + exp(bvi_->val_ - avi_->val_));
bvi_->adj_ += adj_ / (1 + exp(avi_->val_ - bvi_->val_));
}
};
class log_sum_exp_vd_vari : public op_vd_vari {
Expand Down
28 changes: 28 additions & 0 deletions test/unit/math/rev/fun/log_sum_exp_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include <stan/math/rev.hpp>
#include <gtest/gtest.h>
#include <test/unit/math/rev/fun/util.hpp>
#include <test/unit/math/rev/util.hpp>

TEST(log_sum_exp_tests, large_values) {
using stan::math::var;

var a = 1e50;
var output = stan::math::log_sum_exp(a, a);
output.grad();
EXPECT_FLOAT_EQ(output.val(), log(2.0) + value_of(a));
EXPECT_FLOAT_EQ(a.adj(), 1.0);

var a1 = 1e50;
var a2 = 1;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once you remove the other calculate_chain thing we'll need some tests where a1 is a var and a2 is a double.

var output2 = stan::math::log_sum_exp(a1, a2);
output2.grad();
EXPECT_FLOAT_EQ(a1.adj(), 1.0);
EXPECT_FLOAT_EQ(a2.adj(), 0.0);

var a3 = 1;
var a4 = 1e50;
var output3 = stan::math::log_sum_exp(a3, a4);
output3.grad();
EXPECT_FLOAT_EQ(a3.adj(), 0.0);
EXPECT_FLOAT_EQ(a4.adj(), 1.0);
}