[FLINK-34878][cdc][transform] Flink CDC pipeline transform supports CASE WHEN (#3228)

pull/3281/head
Wink 11 months ago committed by Qingsheng Ren
parent e022f4d259
commit 7a1947927d

@ -501,4 +501,13 @@ public class SystemFunctionUtils {
/**
 * Null-safe equality check in which {@code null} never equals anything, not even another
 * {@code null}.
 *
 * @param object1 left-hand value, may be {@code null}
 * @param object2 right-hand value, may be {@code null}
 * @return {@code true} only when both arguments are non-null and
 *     {@code object1.equals(object2)} holds
 */
public static boolean valueEquals(Object object1, Object object2) {
    // A null on either side short-circuits to false before equals() is consulted.
    if (object1 == null || object2 == null) {
        return false;
    }
    return object1.equals(object2);
}
/**
 * Returns the first argument that is not {@code null}.
 *
 * @param objects candidate values, examined in order; the array itself may be {@code null}
 * @return the first non-null element, or {@code null} when there is none (including when no
 *     arguments are supplied or the argument array itself is {@code null})
 */
public static Object coalesce(Object... objects) {
    // A bare coalesce(null) call binds null to the varargs array itself rather than to a
    // one-element array, so iterating without this guard would throw NullPointerException.
    if (objects == null) {
        return null;
    }
    for (Object item : objects) {
        if (item != null) {
            return item;
        }
    }
    return null;
}
}

@ -27,6 +27,7 @@ import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.type.SqlTypeName;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.commons.compiler.Location;
@ -82,17 +83,51 @@ public class JaninoCompiler {
}
/**
 * Translates a parsed SQL node into the text of a Janino expression.
 *
 * <p>NOTE(review): this span is taken from a rendered diff whose +/- markers are missing, so
 * the removed instanceof chain (ending in the translateJaninoAST call) and the added single
 * call to translateSqlNodeToJaninoRvalue appear interleaved here; confirm the final shape
 * against the committed file.
 *
 * @param transform the SQL node to translate
 * @return the string form of the translated rvalue, or an empty string when no rvalue is
 *     produced
 */
public static String translateSqlNodeToJaninoExpression(SqlNode transform) {
if (transform instanceof SqlIdentifier) {
SqlIdentifier sqlIdentifier = (SqlIdentifier) transform;
// Only the last name component of the identifier is used.
return sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
} else if (transform instanceof SqlBasicCall) {
Java.Rvalue rvalue = translateJaninoAST((SqlBasicCall) transform);
Java.Rvalue rvalue = translateSqlNodeToJaninoRvalue(transform);
if (rvalue != null) {
return rvalue.toString();
}
// Fallback when translation produced no rvalue.
return "";
}
private static Java.Rvalue translateJaninoAST(SqlBasicCall sqlBasicCall) {
/**
 * Dispatches a SQL node to the translator that matches its concrete type.
 *
 * @param transform the SQL node to translate
 * @return the Janino rvalue built by the matching translator, or {@code null} when the node is
 *     not a {@code SqlIdentifier}, {@code SqlBasicCall}, {@code SqlCase} or {@code SqlLiteral}
 */
public static Java.Rvalue translateSqlNodeToJaninoRvalue(SqlNode transform) {
    if (transform instanceof SqlIdentifier) {
        SqlIdentifier identifier = (SqlIdentifier) transform;
        return translateSqlIdentifier(identifier);
    }
    if (transform instanceof SqlBasicCall) {
        SqlBasicCall basicCall = (SqlBasicCall) transform;
        return translateSqlBasicCall(basicCall);
    }
    if (transform instanceof SqlCase) {
        SqlCase caseNode = (SqlCase) transform;
        return translateSqlCase(caseNode);
    }
    if (transform instanceof SqlLiteral) {
        SqlLiteral literal = (SqlLiteral) transform;
        return translateSqlSqlLiteral(literal);
    }
    // No translator handles this node type.
    return null;
}
/**
 * Translates an identifier into a Janino rvalue, using only its last name component.
 *
 * <p>Names contained in {@code NO_OPERAND_TIMESTAMP_FUNCTIONS} are handed to
 * {@code generateNoOperandTimestampFunctionOperation}; every other name becomes a plain
 * {@code Java.AmbiguousName}.
 *
 * @param sqlIdentifier the identifier to translate
 * @return the rvalue representing the identifier
 */
private static Java.Rvalue translateSqlIdentifier(SqlIdentifier sqlIdentifier) {
    final int lastIndex = sqlIdentifier.names.size() - 1;
    final String simpleName = sqlIdentifier.names.get(lastIndex);
    if (!NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(simpleName)) {
        return new Java.AmbiguousName(Location.NOWHERE, new String[] {simpleName});
    }
    return generateNoOperandTimestampFunctionOperation(simpleName);
}
/**
 * Translates a SQL literal into a Janino rvalue.
 *
 * <p>A literal whose value is {@code null} becomes a {@code Java.NullLiteral}. Otherwise the
 * value's string form is emitted as a {@code Java.AmbiguousName}, after the quoting steps
 * below.
 *
 * <p>NOTE(review): the literal text is copied verbatim between the double quotes; embedded
 * quotes or backslashes do not appear to be escaped here. Confirm whether callers rely on that.
 *
 * @param sqlLiteral the literal to translate
 * @return the rvalue representing the literal
 */
private static Java.Rvalue translateSqlSqlLiteral(SqlLiteral sqlLiteral) {
    final Object rawValue = sqlLiteral.getValue();
    if (rawValue == null) {
        return new Java.NullLiteral(Location.NOWHERE);
    }

    String text = rawValue.toString();
    if (sqlLiteral instanceof SqlCharStringLiteral) {
        // Janino strings use double quotes: drop the first and last character of the
        // string form and wrap what remains in double quotes.
        final String inner = text.substring(1, text.length() - 1);
        text = '"' + inner + '"';
    }
    if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
        // Literals of these type names are wrapped in double quotes as well.
        text = '"' + text + '"';
    }
    return new Java.AmbiguousName(Location.NOWHERE, new String[] {text});
}
private static Java.Rvalue translateSqlBasicCall(SqlBasicCall sqlBasicCall) {
List<SqlNode> operandList = sqlBasicCall.getOperandList();
List<Java.Rvalue> atoms = new ArrayList<>();
for (SqlNode sqlNode : operandList) {
@ -105,32 +140,44 @@ public class JaninoCompiler {
return sqlBasicCallToJaninoRvalue(sqlBasicCall, atoms.toArray(new Java.Rvalue[0]));
}
/**
 * Translates a SQL {@code CASE} expression into a parenthesized chain of Java conditional
 * ({@code ? :}) expressions.
 *
 * <p>The ELSE operand forms the innermost alternative; each WHEN/THEN pair is then wrapped
 * around it from the last pair to the first, so the first WHEN ends up outermost.
 *
 * @param sqlCase the CASE node to translate
 * @return the parenthesized conditional chain
 */
private static Java.Rvalue translateSqlCase(SqlCase sqlCase) {
    List<Java.Rvalue> conditions = new ArrayList<>();
    sqlCase.getWhenOperands().forEach(node -> translateSqlNodeToAtoms(node, conditions));

    List<Java.Rvalue> results = new ArrayList<>();
    sqlCase.getThenOperands().forEach(node -> translateSqlNodeToAtoms(node, results));

    // Start from the ELSE branch and fold the WHEN/THEN pairs around it, last pair first.
    Java.Rvalue chain = translateSqlNodeToJaninoRvalue(sqlCase.getElseOperand());
    int branch = conditions.size() - 1;
    while (branch >= 0) {
        chain =
                new Java.ConditionalExpression(
                        Location.NOWHERE, conditions.get(branch), results.get(branch), chain);
        branch--;
    }
    return new Java.ParenthesizedExpression(Location.NOWHERE, chain);
}
/**
 * Translates {@code sqlNode} and appends the resulting Janino rvalue(s) to {@code atoms}.
 *
 * <p>A {@code SqlNodeList} is flattened by recursing into each of its elements. A node type
 * that matches none of the branches adds nothing to {@code atoms}.
 *
 * <p>NOTE(review): this span is taken from a rendered diff whose +/- markers are missing, so
 * the removed inline identifier/literal handling and the added calls to the translate*
 * helpers appear interleaved here; confirm the final shape against the committed file.
 *
 * @param sqlNode the SQL node to translate
 * @param atoms output list that receives the translated rvalues
 */
private static void translateSqlNodeToAtoms(SqlNode sqlNode, List<Java.Rvalue> atoms) {
if (sqlNode instanceof SqlIdentifier) {
SqlIdentifier sqlIdentifier = (SqlIdentifier) sqlNode;
// Only the last name component of the identifier is used.
String columnName = sqlIdentifier.names.get(sqlIdentifier.names.size() - 1);
if (NO_OPERAND_TIMESTAMP_FUNCTIONS.contains(columnName)) {
atoms.add(generateNoOperandTimestampFunctionOperation(columnName));
} else {
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {columnName}));
}
atoms.add(translateSqlIdentifier((SqlIdentifier) sqlNode));
} else if (sqlNode instanceof SqlLiteral) {
SqlLiteral sqlLiteral = (SqlLiteral) sqlNode;
String value = sqlLiteral.getValue().toString();
if (sqlLiteral instanceof SqlCharStringLiteral) {
// Double quotation marks represent strings in Janino.
value = "\"" + value.substring(1, value.length() - 1) + "\"";
}
if (SQL_TYPE_NAME_IGNORE.contains(sqlLiteral.getTypeName())) {
value = "\"" + value + "\"";
}
atoms.add(new Java.AmbiguousName(Location.NOWHERE, new String[] {value}));
atoms.add(translateSqlSqlLiteral((SqlLiteral) sqlNode));
} else if (sqlNode instanceof SqlBasicCall) {
atoms.add(translateJaninoAST((SqlBasicCall) sqlNode));
atoms.add(translateSqlBasicCall((SqlBasicCall) sqlNode));
} else if (sqlNode instanceof SqlNodeList) {
// Flatten nested lists: each element contributes its own atom(s).
for (SqlNode node : (SqlNodeList) sqlNode) {
translateSqlNodeToAtoms(node, atoms);
}
} else if (sqlNode instanceof SqlCase) {
atoms.add(translateSqlCase((SqlCase) sqlNode));
}
}

@ -44,7 +44,9 @@ import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
@ -250,10 +252,7 @@ public class TransformParser {
return "";
}
SqlNode where = sqlSelect.getWhere();
if (!(where instanceof SqlBasicCall)) {
throw new ParseException("Unrecognized where: " + where.toString());
}
return JaninoCompiler.translateSqlNodeToJaninoExpression((SqlBasicCall) where);
return JaninoCompiler.translateSqlNodeToJaninoExpression(where);
}
public static List<String> parseComputedColumnNames(String projection) {
@ -307,11 +306,7 @@ public class TransformParser {
return new ArrayList<>();
}
SqlNode where = sqlSelect.getWhere();
if (!(where instanceof SqlBasicCall)) {
throw new ParseException("Unrecognized where: " + where.toString());
}
SqlBasicCall sqlBasicCall = (SqlBasicCall) where;
return parseColumnNameList(sqlBasicCall);
return parseColumnNameList(where);
}
private static List<String> parseColumnNameList(SqlNode sqlNode) {
@ -323,6 +318,9 @@ public class TransformParser {
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
findSqlIdentifier(sqlCase.getWhenOperands().getList(), columnNameList);
}
return columnNameList;
}
@ -336,6 +334,10 @@ public class TransformParser {
} else if (sqlNode instanceof SqlBasicCall) {
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
findSqlIdentifier(sqlBasicCall.getOperandList(), columnNameList);
} else if (sqlNode instanceof SqlCase) {
SqlCase sqlCase = (SqlCase) sqlNode;
SqlNodeList whenOperands = sqlCase.getWhenOperands();
findSqlIdentifier(whenOperands.getList(), columnNameList);
}
}
}

@ -578,9 +578,23 @@ public class TransformDataOperatorTest {
testExpressionConditionTransform("ceil(2.4) = 3.0");
testExpressionConditionTransform("floor(2.5) = 2.0");
testExpressionConditionTransform("round(3.1415926,2) = 3.14");
testExpressionConditionTransform("IF(2>0,1,0) = 1");
testExpressionConditionTransform("COALESCE(null,1,2) = 1");
testExpressionConditionTransform("1 + 1 = 2");
testExpressionConditionTransform("1 - 1 = 0");
testExpressionConditionTransform("1 * 1 = 1");
testExpressionConditionTransform("3 % 2 = 1");
testExpressionConditionTransform("1 < 2");
testExpressionConditionTransform("1 <= 1");
testExpressionConditionTransform("1 > 0");
testExpressionConditionTransform("1 >= 1");
testExpressionConditionTransform(
"case 1 when 1 then 'a' when 2 then 'b' else 'c' end = 'a'");
testExpressionConditionTransform("case col1 when '1' then true else false end");
testExpressionConditionTransform("case when col1 = '1' then true else false end");
}
void testExpressionConditionTransform(String expression) throws Exception {
private void testExpressionConditionTransform(String expression) throws Exception {
TransformDataOperator transform =
TransformDataOperator.newBuilder()
.addTransform(

@ -21,8 +21,6 @@ import org.apache.flink.cdc.common.schema.Schema;
import org.apache.flink.cdc.common.types.DataTypes;
import org.apache.flink.cdc.runtime.parser.metadata.TransformSchemaFactory;
import org.apache.flink.cdc.runtime.parser.metadata.TransformSqlOperatorTable;
import org.apache.flink.table.api.ApiExpression;
import org.apache.flink.table.api.Expressions;
import org.apache.calcite.config.CalciteConnectionConfigImpl;
import org.apache.calcite.jdbc.CalciteSchema;
@ -260,13 +258,12 @@ public class TransformParserTest {
testFilterExpression("upper(lower(id))", "upper(lower(id))");
testFilterExpression(
"abs(uniq_id) > 10 and id is not null", "abs(uniq_id) > 10 && null != id");
}
@Test
public void testSqlCall() {
ApiExpression apiExpression = Expressions.concat("1", "2");
ApiExpression substring = apiExpression.substring(1);
System.out.println(substring);
testFilterExpression(
"case id when 1 then 'a' when 2 then 'b' else 'c' end",
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
testFilterExpression(
"case when id = 1 then 'a' when id = 2 then 'b' else 'c' end",
"(valueEquals(id, 1) ? \"a\" : valueEquals(id, 2) ? \"b\" : \"c\")");
}
private void testFilterExpression(String expression, String expressionExpect) {

Loading…
Cancel
Save