Skip to content

Commit

Permalink
improve negativity check in formula generation (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckrause authored Dec 11, 2024
1 parent d6e96be commit 7b0988f
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 19 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ To install or update LODA, please follow the [installation instructions](https:/

## [Unreleased]

### Enhancements

* Improve negativity check in formula generation

## v24.12.8

### Enhancements
Expand Down
25 changes: 17 additions & 8 deletions src/form/expression_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,18 +328,19 @@ bool ExpressionUtil::hasNonRecursiveFunctionReference(
}
}

bool ExpressionUtil::canBeNegative(const Expression& e) {
bool ExpressionUtil::canBeNegative(const Expression& e, int64_t offset) {
switch (e.type) {
case Expression::Type::CONSTANT:
return e.value < Number::ZERO;
case Expression::Type::PARAMETER:
return false;
return offset < 0;
case Expression::Type::LOCAL:
return true;
case Expression::Type::FUNCTION:
if (e.name == "max" &&
std::any_of(e.children.begin(), e.children.end(),
[](const Expression& c) { return !canBeNegative(c); })) {
if (e.name == "max" && std::any_of(e.children.begin(), e.children.end(),
[&](const Expression& c) {
return !canBeNegative(c, offset);
})) {
return false;
} else if (e.name == "binomial" || e.name == "floor" ||
e.name == "truncate") {
Expand All @@ -358,15 +359,23 @@ bool ExpressionUtil::canBeNegative(const Expression& e) {
return e.children[1].value.odd();
}
}
case Expression::Type::SUM:
case Expression::Type::SUM: {
if (e.children.size() == 2 &&
e.children[0].type == Expression::Type::PARAMETER &&
e.children[1].type == Expression::Type::CONSTANT) {
return Number(-offset) > e.children[1].value;
}
break; // infer from children
}
case Expression::Type::PRODUCT:
case Expression::Type::FRACTION:
case Expression::Type::MODULUS:
case Expression::Type::IF:
break; // infer from children
}
return std::any_of(e.children.begin(), e.children.end(),
[](const Expression& c) { return canBeNegative(c); });
return std::any_of(
e.children.begin(), e.children.end(),
[&](const Expression& c) { return canBeNegative(c, offset); });
}

void ExpressionUtil::collectNames(const Expression& e, Expression::Type type,
Expand Down
2 changes: 1 addition & 1 deletion src/form/expression_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ExpressionUtil {
const Expression& e, const std::vector<std::string>& names,
int64_t max_offset);

static bool canBeNegative(const Expression& e);
static bool canBeNegative(const Expression& e, int64_t offset);

static void collectNames(const Expression& e, Expression::Type type,
std::set<std::string>& target);
Expand Down
17 changes: 9 additions & 8 deletions src/form/formula_gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ Expression FormulaGenerator::operandToExpression(Operand op) const {
throw std::runtime_error("internal error"); // unreachable
}

Expression divToFraction(const Expression& numerator,
const Expression& denominator) {
Expression FormulaGenerator::divToFraction(
const Expression& numerator, const Expression& denominator) const {
Expression frac(Expression::Type::FRACTION, "", {numerator, denominator});
std::string func = "floor";
if (ExpressionUtil::canBeNegative(numerator) ||
ExpressionUtil::canBeNegative(denominator)) {
if (ExpressionUtil::canBeNegative(numerator, offset) ||
ExpressionUtil::canBeNegative(denominator, offset)) {
func = "truncate";
}
Expression wrapper(Expression::Type::FUNCTION, func, {frac});
Expand Down Expand Up @@ -105,7 +105,7 @@ bool FormulaGenerator::update(const Operation& op) {
}
case Operation::Type::POW: {
res = Expression(Expression::Type::POWER, "", {prevTarget, source});
if (ExpressionUtil::canBeNegative(source)) {
if (ExpressionUtil::canBeNegative(source, offset)) {
Expression wrapper(Expression::Type::FUNCTION, "truncate", {res});
res = wrapper;
}
Expand All @@ -114,8 +114,8 @@ bool FormulaGenerator::update(const Operation& op) {
case Operation::Type::MOD: {
auto c1 = prevTarget;
auto c2 = source;
if (ExpressionUtil::canBeNegative(c1) ||
ExpressionUtil::canBeNegative(c2)) {
if (ExpressionUtil::canBeNegative(c1, offset) ||
ExpressionUtil::canBeNegative(c2, offset)) {
Expression wrapper(Expression::Type::SUM);
wrapper.newChild(c1);
wrapper.newChild(Expression::Type::PRODUCT);
Expand Down Expand Up @@ -364,7 +364,7 @@ void FormulaGenerator::prepareForPostLoop(
right = preloopExprs.at(cell);
} else {
auto safe_param = preloopCounter;
if (ExpressionUtil::canBeNegative(safe_param)) {
if (ExpressionUtil::canBeNegative(safe_param, offset)) {
auto tmp = safe_param;
safe_param = Expression(Expression::Type::FUNCTION, "max",
{tmp, ExpressionUtil::newConstant(0)});
Expand Down Expand Up @@ -463,6 +463,7 @@ bool FormulaGenerator::generate(const Program& p, int64_t id, Formula& result,
}
formula.clear();
freeNameIndex = 0;
offset = ProgramUtil::getOffset(p);
if (!generateSingle(p)) {
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions src/form/formula_gen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class FormulaGenerator {

void initFormula(int64_t numCells, bool useIncEval);

Expression divToFraction(const Expression& numerator,
const Expression& denominator) const;

bool update(const Operation& op);

bool update(const Program& p);
Expand All @@ -59,4 +62,5 @@ class FormulaGenerator {
Formula formula;
std::map<int64_t, std::string> cellNames;
size_t freeNameIndex;
int64_t offset;
};
3 changes: 2 additions & 1 deletion src/form/pari.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ bool convertExprToPari(Expression& expr, const Formula& f, bool as_vector) {
}
if (expr.type == Expression::Type::FUNCTION && expr.name == "binomial") {
// TODO: check feedback from PARI team to avoid this limitation
if (ExpressionUtil::canBeNegative(expr.children.at(1))) {
// note also that we should use the proper offset here
if (ExpressionUtil::canBeNegative(expr.children.at(1), 0)) {
return false;
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/formula/formula.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ A000165: a(n) = b(2*n), b(n) = n*b(n-2), b(1) = 1, b(0) = 1
A000168: a(n) = truncate((2*A151383(n))/(n+2)), A151383(n) = floor((floor(binomial(2*n,n)/(n+1))*3^(n+1))/3)
A000212: a(n) = floor((n^2)/3)
A000247: a(n) = 2^n-n-2
A000272: a(n) = n^max(n-2,0)
A000276: a(n) = truncate((2*c(n+1)*(n+3))/2), b(n) = b(n-1)*(n+1), b(2) = 6, b(1) = 2, b(0) = 1, c(n) = c(n-1)*(n+1)+b(n-1), c(2) = 5, c(1) = 1, c(0) = 0
A000278: a(n) = (-a(n-2))^2+a(n-1), a(1) = 1, a(0) = 0
A000280: a(n) = a(n-2)^3+a(n-1), a(1) = 1, a(0) = 0
A000284: a(n) = a(n-1)^3+a(n-2), a(3) = 2, a(2) = 1, a(1) = 1, a(0) = 0
A000289: a(n) = a(n-1)*(a(n-1)-3)+3, a(2) = 7, a(1) = 4, a(0) = 1
A000290: a(n) = n^2
A000295: a(n) = 2^n-n-1
A000317: a(n) = b(max(n-1,0)), b(n) = b(n-2)^2+max(-b(n-2)+b(n-1)-1,0)*b(n-1)+b(n-1), b(3) = 7, b(2) = 3, b(1) = 2, b(0) = 1
A000317: a(n) = b(n-1), b(n) = b(n-2)^2+max(-b(n-2)+b(n-1)-1,0)*b(n-1)+b(n-1), b(3) = 7, b(2) = 3, b(1) = 2, b(0) = 1
A000321: a(n) = -a(n-2)*(2*n-2)-a(n-1), a(2) = -1, a(1) = -1, a(0) = 1
A000407: a(n) = a(n-1)*(4*n+2), a(1) = 6, a(0) = 1
A000463: a(n) = gcd(n+1,floor(n/2)+1)*(floor(n/2)+1)
Expand Down
6 changes: 6 additions & 0 deletions tests/programs/oeis/000/A000272.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
; A000272: Number of trees on n labeled nodes: n^(n-2) with a(0)=1.
; 1,1,1,3,16,125,1296,16807,262144,4782969,100000000,2357947691,61917364224,1792160394037,56693912375296,1946195068359375,72057594037927936,2862423051509815793,121439531096594251776,5480386857784802185939,262144000000000000000000,13248496640331026125580781,705429498686404044207947776,39471584120695485887249589623,2315513501476187716057433112576,142108547152020037174224853515625,9106685769537214956799814036094976,608266787713357709119683992618861307,42277452950578284263485622772148731904

mov $1,$0
trn $1,2
pow $0,$1

0 comments on commit 7b0988f

Please sign in to comment.