[jOOQ/jOOQ#11547] Add test for empty sets and warning about "catastrophic cancellation"

This commit is contained in:
Lukas Eder 2021-03-03 12:12:16 +01:00
parent aca6a4181b
commit 74d6896db6
10 changed files with 162 additions and 73 deletions

View File

@ -51,6 +51,8 @@ import static org.jooq.impl.Keywords.K_DISTINCT;
import static org.jooq.impl.Keywords.K_FILTER;
import static org.jooq.impl.Keywords.K_ORDER_BY;
import static org.jooq.impl.Keywords.K_WHERE;
import static org.jooq.impl.SQLDataType.DOUBLE;
import static org.jooq.impl.SQLDataType.NUMERIC;
import java.util.Arrays;
import java.util.Collection;
@ -301,16 +303,30 @@ implements
}
/**
* Type safe <code>NVL2(y, x, null)</code> for REGR emulations.
* Type safe <code>NVL2(y, x, null)</code> for statistical function
* emulations.
*/
final <U extends Number> Field<U> x(Field<U> x, Field<? extends Number> y) {
    // NVL2(y, x, NULL): x is kept only where y is not NULL. The fallback
    // NULL is typed with x's own data type so the result stays a Field<U>.
    final Field<U> typedNull = DSL.NULL(x.getDataType());
    return DSL.nvl2(y, x, typedNull);
}
/**
* Type safe <code>NVL2(x, y, null)</code> for REGR emulations.
* Type safe <code>NVL2(x, y, null)</code> for statistical function
* emulations.
*/
final <U extends Number> Field<U> y(Field<? extends Number> x, Field<U> y) {
    // NVL2(x, y, NULL): y is kept only where x is not NULL. The fallback
    // NULL is typed with y's own data type so the result stays a Field<U>.
    final Field<U> typedNull = DSL.NULL(y.getDataType());
    return DSL.nvl2(x, y, typedNull);
}
/**
 * Pick the numeric data type that emulated statistical functions cast
 * their operands to.
 *
 * @param ctx The rendering context, whose dialect family drives the choice.
 * @return <code>DOUBLE</code> for the SQLite family, <code>NUMERIC</code>
 *         for every other family.
 */
final DataType<? extends Number> d(Context<?> ctx) {
    final DataType<? extends Number> castType;

    switch (ctx.family()) {
        case SQLITE:
            castType = DOUBLE;
            break;

        default:
            castType = NUMERIC;
            break;
    }

    return castType;
}
}

View File

@ -96,7 +96,7 @@ extends
Field<? extends Number> x = (Field) getArguments().get(0);
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(fo(DSL.sum(x.times(y))).minus(fo(DSL.sum(x(x, y))).times(fo(DSL.sum(y(x, y)))).div(fo(DSL.count(x.plus(y))).cast(NUMERIC))).div(fo(DSL.count(x.plus(y))).cast(NUMERIC)));
ctx.visit(fo(DSL.sum(x.times(y))).minus(fo(DSL.sum(x(x, y))).times(fo(DSL.sum(y(x, y)))).div(fo(DSL.count(x.plus(y))).cast(d(ctx)))).div(fo(DSL.count(x.plus(y))).cast(d(ctx))));
}
else
super.accept(ctx);

View File

@ -96,7 +96,7 @@ extends
Field<? extends Number> x = (Field) getArguments().get(0);
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(fo(DSL.sum(x.times(y))).minus(fo(DSL.sum(x(x, y))).times(fo(DSL.sum(y(x, y)))).div(fo(DSL.count(x.plus(y))).cast(NUMERIC))).div(fo(DSL.count(x.plus(y))).cast(NUMERIC).minus(DSL.inline(1))));
ctx.visit(fo(DSL.sum(x.times(y))).minus(fo(DSL.sum(x(x, y))).times(fo(DSL.sum(y(x, y)))).div(fo(DSL.count(x.plus(y))).cast(d(ctx)))).div(fo(DSL.count(x.plus(y))).cast(d(ctx)).minus(DSL.inline(1))));
}
else
super.accept(ctx);

View File

@ -18720,6 +18720,11 @@ public class DSL {
/**
* The <code>CORR</code> function.
* <p>
* Calculate the correlation coefficient. This standard SQL function may be supported
* natively, or emulated using {@link #covarPop(Field, Field)} and {@link #stddevPop(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, FIREBIRD, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@ -18729,6 +18734,11 @@ public class DSL {
/**
* The <code>COVAR_SAMP</code> function.
* <p>
* Calculate the sample covariance. This standard SQL function may be supported natively,
* or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an emulation
* is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18738,6 +18748,11 @@ public class DSL {
/**
* The <code>COVAR_POP</code> function.
* <p>
* Calculate the population covariance. This standard SQL function may be supported
* natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an
* emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18756,6 +18771,11 @@ public class DSL {
/**
* The <code>REGR_AVGX</code> function.
* <p>
* Calculate the average of the independent values (x). This standard SQL function may
* be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18765,6 +18785,11 @@ public class DSL {
/**
* The <code>REGR_AVGY</code> function.
* <p>
* Calculate the average of the dependent values (y). This standard SQL function may
* be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18774,6 +18799,11 @@ public class DSL {
/**
* The <code>REGR_COUNT</code> function.
* <p>
* Calculate the number of non-<code>NULL</code> pairs. This standard SQL function may
* be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18783,6 +18813,11 @@ public class DSL {
/**
* The <code>REGR_INTERCEPT</code> function.
* <p>
* Calculate the y intercept of the regression line. This standard SQL function may
* be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18792,33 +18827,53 @@ public class DSL {
/**
* The <code>REGR_R2</code> function.
* <p>
* Calculate the coefficient of determination. This standard SQL function may be supported
* natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an
* emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@Support({ CUBRID, FIREBIRD, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
public static AggregateFunction<BigDecimal> regrR2(Field<? extends Number> y, Field<? extends Number> x) {
// Argument order is (y, x): the dependent value comes first, the independent value second.
return new RegrR2(y, x);
}
/**
* The <code>REGR_SLOPE</code> function.
* <p>
* Calculate the slope of the regression line. This standard SQL function may be supported
* natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an
* emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@Support
public static AggregateFunction<BigDecimal> regrSlope(Field<? extends Number> y, Field<? extends Number> x) {
// Argument order is (y, x): the dependent value comes first, the independent value second.
return new RegrSlope(y, x);
}
/**
* The <code>REGR_SXX</code> function.
* <p>
* Calculate the <code>REGR_SXX</code> auxiliary function. This standard SQL function
* may be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@Support
public static AggregateFunction<BigDecimal> regrSXX(Field<? extends Number> y, Field<? extends Number> x) {
// Argument order is (y, x): the dependent value comes first, the independent value second.
return new RegrSxx(y, x);
}
/**
* The <code>REGR_SXY</code> function.
* <p>
* Calculate the <code>REGR_SXY</code> auxiliary function. This standard SQL function
* may be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support
@ -18828,15 +18883,25 @@ public class DSL {
/**
* The <code>REGR_SYY</code> function.
* <p>
* Calculate the <code>REGR_SYY</code> auxiliary function. This standard SQL function
* may be supported natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}.
* If an emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@Support
public static AggregateFunction<BigDecimal> regrSYY(Field<? extends Number> y, Field<? extends Number> x) {
// Argument order is (y, x): the dependent value comes first, the independent value second.
return new RegrSyy(y, x);
}
/**
* The <code>STDDEV_POP</code> function.
* <p>
* Calculate the population standard deviation. This standard SQL function may be supported
* natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an
* emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@ -18846,6 +18911,11 @@ public class DSL {
/**
* The <code>STDDEV_SAMP</code> function.
* <p>
* Calculate the sample standard deviation. This standard SQL function may be supported
* natively, or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an
* emulation is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@ -18855,15 +18925,25 @@ public class DSL {
/**
* The <code>VAR_POP</code> function.
* <p>
* Calculate the population variance. This standard SQL function may be supported natively,
* or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an emulation
* is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })
@Support
public static AggregateFunction<BigDecimal> varPop(Field<? extends Number> field) {
// Only wraps the argument here; native rendering vs. emulation is presumably
// decided later by VarPop when it is rendered for a dialect - TODO confirm.
return new VarPop(field);
}
/**
* The <code>VAR_SAMP</code> function.
* <p>
* Calculate the sample variance. This standard SQL function may be supported natively,
* or emulated using {@link #sum(Field)} and {@link #count(Field)}. If an emulation
* is applied, beware of the risk of "<a href="https://en.wikipedia.org/wiki/Catastrophic_cancellation">Catastrophic
* cancellation</a>" in case the calculations are performed using floating point arithmetic.
*/
@NotNull
@Support({ CUBRID, H2, HSQLDB, MARIADB, MYSQL, POSTGRES })

View File

@ -10696,47 +10696,51 @@ final class DefaultParseContext extends AbstractScope implements ParseContext {
}
private final AggregateFunction<?> parseBinarySetFunctionIf() {
Field<? extends Number> arg1;
Field<? extends Number> arg2;
BinarySetFunctionType type = parseBinarySetFunctionTypeIf();
switch (characterUpper()) {
case 'C':
if (parseFunctionNameIf("CORR"))
return parseBindarySetFunction(DSL::corr);
else if (parseFunctionNameIf("COVAR_POP"))
return parseBindarySetFunction(DSL::covarPop);
else if (parseFunctionNameIf("COVAR_SAMP"))
return parseBindarySetFunction(DSL::covarSamp);
if (type == null)
return null;
break;
case 'R':
if (parseFunctionNameIf("REGR_AVGX"))
return parseBindarySetFunction(DSL::regrAvgX);
else if (parseFunctionNameIf("REGR_AVGY"))
return parseBindarySetFunction(DSL::regrAvgY);
else if (parseFunctionNameIf("REGR_COUNT"))
return parseBindarySetFunction(DSL::regrCount);
else if (parseFunctionNameIf("REGR_INTERCEPT"))
return parseBindarySetFunction(DSL::regrIntercept);
else if (parseFunctionNameIf("REGR_R2"))
return parseBindarySetFunction(DSL::regrR2);
else if (parseFunctionNameIf("REGR_SLOPE"))
return parseBindarySetFunction(DSL::regrSlope);
else if (parseFunctionNameIf("REGR_SXX"))
return parseBindarySetFunction(DSL::regrSXX);
else if (parseFunctionNameIf("REGR_SXY"))
return parseBindarySetFunction(DSL::regrSXY);
else if (parseFunctionNameIf("REGR_SYY"))
return parseBindarySetFunction(DSL::regrSYY);
break;
}
return null;
}
private final AggregateFunction<?> parseBindarySetFunction(BiFunction<? super Field<? extends Number>, ? super Field<? extends Number>, ? extends AggregateFunction<?>> function) {
parse('(');
arg1 = (Field) toField(parseNumericOp(N));
Field<? extends Number> arg1 = (Field) parseField();
parse(',');
arg2 = (Field) toField(parseNumericOp(N));
Field<? extends Number> arg2 = (Field) parseField();
parse(')');
switch (type) {
case CORR:
return corr(arg1, arg2);
case COVAR_POP:
return covarPop(arg1, arg2);
case COVAR_SAMP:
return covarSamp(arg1, arg2);
case REGR_AVGX:
return regrAvgX(arg1, arg2);
case REGR_AVGY:
return regrAvgY(arg1, arg2);
case REGR_COUNT:
return regrCount(arg1, arg2);
case REGR_INTERCEPT:
return regrIntercept(arg1, arg2);
case REGR_R2:
return regrR2(arg1, arg2);
case REGR_SLOPE:
return regrSlope(arg1, arg2);
case REGR_SXX:
return regrSXX(arg1, arg2);
case REGR_SXY:
return regrSXY(arg1, arg2);
case REGR_SYY:
return regrSYY(arg1, arg2);
default:
throw exception("Binary set function not supported: " + type);
}
return function.apply(arg1, arg2);
}
private final AggregateFilterStep<?> parseOrderedSetFunctionIf() {
@ -12432,16 +12436,6 @@ final class DefaultParseContext extends AbstractScope implements ParseContext {
return null;
}
private final BinarySetFunctionType parseBinarySetFunctionTypeIf() {

    // TODO speed this up
    // Linear probe: try every known binary set function name, in the
    // enum's declaration order, and stop at the first one that parses.
    final BinarySetFunctionType[] candidates = BinarySetFunctionType.values();

    for (int i = 0; i < candidates.length; i++) {
        final BinarySetFunctionType candidate = candidates[i];

        if (parseFunctionNameIf(candidate.name())) {
            return candidate;
        }
    }

    // No binary set function name at the current position.
    return null;
}
private final Comparator parseComparatorIf() {
if (parseIf("="))
return Comparator.EQUALS;
@ -13014,21 +13008,6 @@ final class DefaultParseContext extends AbstractScope implements ParseContext {
// INTERSECTION;
}
/**
 * The binary (two-argument) set functions known to the parser.
 * <p>
 * Each constant's {@link #name()} is the function name that
 * <code>parseBinarySetFunctionTypeIf()</code> tries to parse, and constants
 * are probed in declaration order, so neither the names nor their order
 * should be changed casually.
 */
private static enum BinarySetFunctionType {
CORR,
COVAR_POP,
COVAR_SAMP,
REGR_SLOPE,
REGR_INTERCEPT,
REGR_COUNT,
REGR_R2,
REGR_AVGX,
REGR_AVGY,
REGR_SXX,
REGR_SYY,
REGR_SXY,
}
private static final String[] KEYWORDS_IN_STATEMENTS = {
"ALTER",
"BEGIN",

View File

@ -96,7 +96,7 @@ extends
Field<? extends Number> x = (Field) getArguments().get(0);
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(fo(DSL.avg(DSL.nvl2(x, y, DSL.NULL(NUMERIC)).cast(NUMERIC))));
ctx.visit(fo(DSL.avg(DSL.nvl2(x, y, DSL.NULL(d(ctx))).cast(d(ctx)))));
}
else
super.accept(ctx);

View File

@ -96,7 +96,7 @@ extends
Field<? extends Number> x = (Field) getArguments().get(0);
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(fo(DSL.avg(DSL.nvl2(y, x, DSL.NULL(NUMERIC)).cast(NUMERIC))));
ctx.visit(fo(DSL.avg(DSL.nvl2(y, x, DSL.NULL(d(ctx))).cast(d(ctx)))));
}
else
super.accept(ctx);

View File

@ -96,7 +96,7 @@ extends
Field<? extends Number> x = (Field) getArguments().get(0);
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(fo(DSL.avg(x(x, y).cast(NUMERIC))).minus(fo(DSL.regrSlope(x, y)).times(fo(DSL.avg(y(x, y).cast(NUMERIC))))));
ctx.visit(fo(DSL.avg(x(x, y).cast(d(ctx)))).minus(fo(DSL.regrSlope(x, y)).times(fo(DSL.avg(y(x, y).cast(d(ctx)))))));
}
else
super.accept(ctx);

View File

@ -97,7 +97,7 @@ extends
Field<? extends Number> y = (Field) getArguments().get(1);
ctx.visit(DSL
.when(fo(DSL.varPop(y(x, y))).eq(inline(BigDecimal.ZERO)), DSL.NULL(NUMERIC))
.when(fo(DSL.varPop(y(x, y))).eq(inline(BigDecimal.ZERO)), (Field) DSL.NULL(d(ctx)))
.when(fo(DSL.varPop(x(x, y))).eq(inline(BigDecimal.ZERO)), inline(BigDecimal.ONE))
.else_(DSL.square(fo(DSL.corr(x, y))))
);

View File

@ -85,6 +85,20 @@ extends
private static final Set<SQLDialect> NO_SUPPORT_NATIVE = SQLDialect.supportedUntil(DERBY, IGNITE, SQLITE);
@SuppressWarnings("unchecked")
@Override
public void accept(Context<?> ctx) {

    // Dialects outside of NO_SUPPORT_NATIVE get the default rendering.
    if (!NO_SUPPORT_NATIVE.contains(ctx.dialect())) {
        super.accept(ctx);
        return;
    }

    // Emulation: AVG(x * x) - AVG(x) * AVG(x), built from the first argument.
    Field<? extends Number> x = (Field) getArguments().get(0);
    Field<?> meanOfSquares = fo(DSL.avg(DSL.square(x)));
    Field<?> squareOfMean = DSL.square(fo(DSL.avg(x)));

    ctx.visit(meanOfSquares.minus(squareOfMean));
}
@Override
void acceptFunctionName(Context<?> ctx) {
switch (ctx.family()) {