diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLCommandPackets.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLCommandPackets.scala new file mode 100644 index 000000000..9160f0b80 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLCommandPackets.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import io.netty.buffer.ByteBuf + +import org.apache.kyuubi.server.mysql.MySQLRichByteBuf.Implicit +import org.apache.kyuubi.server.mysql.constant.MySQLCommandPacketType + +sealed abstract class MySQLCommandPacket( + cmdType: MySQLCommandPacketType +) extends MySQLPacket { + override def sequenceId: Int = 0 +} + +case class MySQLComPingPacket() + extends MySQLCommandPacket(MySQLCommandPacketType.COM_PING) + +case class MySQLComQuitPacket() + extends MySQLCommandPacket(MySQLCommandPacketType.COM_QUIT) + +object MySQLComInitDbPacket extends SupportsDecode[MySQLComInitDbPacket] { + override def decode(payload: ByteBuf): MySQLComInitDbPacket = { + val schema = payload.readStringEOF + MySQLComInitDbPacket(schema) + } +} +case class MySQLComInitDbPacket( + database: String +) extends MySQLCommandPacket(MySQLCommandPacketType.COM_INIT_DB) + +object MySQLComFieldListPacket extends SupportsDecode[MySQLComFieldListPacket] { + override def decode(payload: ByteBuf): MySQLComFieldListPacket = { + val table = payload.readStringNul + val fieldWildcard = payload.readStringEOF + MySQLComFieldListPacket(table, fieldWildcard) + } +} + +case class MySQLComFieldListPacket( + table: String, + fieldWildcard: String +) extends MySQLCommandPacket(MySQLCommandPacketType.COM_FIELD_LIST) + +object MySQLComQueryPacket extends SupportsDecode[MySQLComQueryPacket] { + override def decode(payload: ByteBuf): MySQLComQueryPacket = { + val sql = payload.readStringEOF + MySQLComQueryPacket(sql) + } +} + +case class MySQLComQueryPacket( + sql: String +) extends MySQLCommandPacket(MySQLCommandPacketType.COM_QUERY) + +case class MySQLUnsupportedCommandPacket( + cmdType: MySQLCommandPacketType +) extends MySQLCommandPacket(cmdType) diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDataPackets.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDataPackets.scala new file mode 100644 index 000000000..624b691c6 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDataPackets.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import java.lang.{Boolean => JBoolean} +import java.math.BigDecimal +import java.sql.Timestamp +import java.time.LocalDateTime + +import io.netty.buffer.ByteBuf + +import org.apache.kyuubi.server.mysql.MySQLDateTimeUtils._ +import org.apache.kyuubi.server.mysql.MySQLRichByteBuf.Implicit +import org.apache.kyuubi.server.mysql.constant.{MySQLDataType, MySQLServerDefines} + +case class MySQLFieldCountPacket( + sequenceId: Int, + columnCount: Int +) extends MySQLPacket with SupportsEncode { + + override def encode(payload: ByteBuf): Unit = { + payload.writeIntLenenc(columnCount) + } +} + +case class MySQLColumnDefinition41Packet( + sequenceId: Int, + flags: Int, + name: String, + columnLength: Int, + columnType: MySQLDataType, + decimals: Int +) extends MySQLPacket with SupportsEncode { + + def nextLength: Int = 0x0c + + def characterSet: Int = MySQLServerDefines.CHARSET + + def catalog: String = "" + + def database: String = "" + + def table: String = "" + + def originalTable: String = "" + + def originalName: String = "" + + def containDefaultValues: Boolean = false + + override def encode(payload: ByteBuf): Unit = { + payload.writeStringLenenc(catalog) + payload.writeStringLenenc(database) + payload.writeStringLenenc(table) + payload.writeStringLenenc(originalTable) + payload.writeStringLenenc(name) + payload.writeStringLenenc(originalName) + payload.writeIntLenenc(nextLength) + payload.writeInt2(characterSet) + payload.writeInt4(columnLength) + payload.writeInt1(columnType.value) + payload.writeInt2(flags) + payload.writeInt1(decimals) + payload.writeReserved(2) + if (containDefaultValues) { + payload.writeIntLenenc(0) + payload.writeStringLenenc("") + } + } +} + +case class MySQLTextResultSetRowPacket( + sequenceId: Int, + row: Seq[Any] +) extends MySQLPacket with SupportsEncode { + + private def nullVal = 0xfb + + override def encode(payload: ByteBuf): Unit = { + row.foreach { + case null => payload.writeInt1(nullVal) + // TODO check all possible data types returned from backend service + case bytes: Array[Byte] => payload.writeBytesLenenc(bytes) + case ts: Timestamp if ts.getNanos == 0 => + payload.writeStringLenenc(ts.toString.split("\\.")(0)) + case decimal: BigDecimal => payload.writeStringLenenc(decimal.toPlainString) + case JBoolean.TRUE | true => payload.writeBytesLenenc(Array[Byte]('1')) + case JBoolean.FALSE | false => payload.writeBytesLenenc(Array[Byte]('0')) + case time: LocalDateTime => payload.writeStringLenenc(dtFmt.format(time)) + case other => payload.writeStringLenenc(other.toString) + } + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDateTimeUtils.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDateTimeUtils.scala new file mode 100644 index 000000000..89409bc0b --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLDateTimeUtils.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import java.time.format.DateTimeFormatter + +object MySQLDateTimeUtils { + val dtFmt: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss") +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLGenericPackets.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLGenericPackets.scala new file mode 100644 index 000000000..4e9e40ef6 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLGenericPackets.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import java.sql.SQLException + +import io.netty.buffer.ByteBuf + +import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.server.mysql.MySQLRichByteBuf.Implicit +import org.apache.kyuubi.server.mysql.constant._ + +case class MySQLOKPacket( + sequenceId: Int = 0, + affectedRows: Long = 0L, + lastInsertId: Long = 0L +) extends MySQLPacket with SupportsEncode { + + def header: Int = 0x00 + + def statusFlag: MySQLStatusFlag = MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT + + def warnings: Int = 0 + + def info: String = "" + + override def encode(payload: ByteBuf): Unit = { + payload.writeInt1(header) + payload.writeIntLenenc(affectedRows) + payload.writeIntLenenc(lastInsertId) + payload.writeInt2(statusFlag.value) + payload.writeInt2(warnings) + payload.writeStringEOF(info) + } +} + +object MySQLErrPacket { + def apply(cause: Throwable): MySQLErrPacket = { + cause match { + case kse: KyuubiSQLException if kse.getCause != null => + // prefer brief nested error message instead of whole stacktrace + apply(kse.getCause) + case e: Exception if e.getMessage contains "NoSuchDatabaseException" => + MySQLErrPacket(1, MySQLErrorCode.ER_BAD_DB_ERROR, cause.getMessage) + case se: SQLException if se.getSQLState == null => + MySQLErrPacket(1, MySQLErrorCode.ER_INTERNAL_ERROR, cause.getMessage) + case se: SQLException => + MySQLErrPacket(1, MySQLErrorCode(se.getErrorCode, se.getSQLState, se.getMessage)) + case _ => + MySQLErrPacket(1, MySQLErrorCode.UNKNOWN_EXCEPTION, cause.getMessage) + } + } +} + +case class MySQLErrPacket( + sequenceId: Int, + sqlErrorCode: MySQLErrorCode, + errMsgArgs: String* +) extends MySQLPacket with SupportsEncode { + + def header: Int = 0xff + + def sqlStateMarker: String = "#" + + def errorCode: Int = sqlErrorCode.errorCode + + def sqlState: String = sqlErrorCode.sqlState + + def errorMessage: String = sqlErrorCode.errorMessage format (errMsgArgs: _*) + + override def encode(payload: ByteBuf): Unit = { + payload.writeInt1(header) + payload.writeInt2(errorCode) + payload.writeStringFix(sqlStateMarker) + payload.writeStringFix(sqlState) + payload.writeStringEOF(errorMessage) + } +} + +case class MySQLEofPacket( + sequenceId: Int = 0 +) extends MySQLPacket with SupportsEncode { + + def header: Int = 0xfe + + def warnings: Int = 0 + + def statusFlags: MySQLStatusFlag = MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT + + override def encode(payload: ByteBuf): Unit = { + payload.writeInt1(header) + payload.writeInt2(warnings) + payload.writeInt2(statusFlags.value) + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLPacket.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLPacket.scala new file mode 100644 index 000000000..461a42976 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLPacket.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import io.netty.buffer.ByteBuf + +/** + * The required `MySQLPacket`s are split into 4 groups, which are GENERIC, AUTHENTICATION, + * COMMAND and DATA. + *

+ * As Kyuubi is designed to be a "MySQL Server", only part of packets need to be encodable + * and others just need to be decodable. + */ +trait MySQLPacket { + def sequenceId: Int +} + +trait SupportsEncode extends MySQLPacket { + def encode(payload: ByteBuf): Unit +} + +trait SupportsDecode[T <: MySQLPacket] { + def decode(payload: ByteBuf): T +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLRichByteBuf.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLRichByteBuf.scala new file mode 100644 index 000000000..70a59bb4e --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/MySQLRichByteBuf.scala @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import java.nio.charset.StandardCharsets + +import io.netty.buffer.ByteBuf + +// https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol +object MySQLRichByteBuf { + + private def charset = StandardCharsets.UTF_8 + + implicit class Implicit(self: ByteBuf) { + + /** + * Read 1 byte fixed length integer from byte buffers. + * + * @return 1 byte fixed length integer + */ + def readInt1: Int = self.readUnsignedByte + + /** + * Write 1 byte fixed length integer to byte buffers. + * + * @param value 1 byte fixed length integer + */ + def writeInt1(value: Int): ByteBuf = self.writeByte(value) + + /** + * Read 2 byte fixed length integer from byte buffers. + * + * @return 2 byte fixed length integer + */ + def readInt2: Int = self.readUnsignedShortLE + + /** + * Write 2 byte fixed length integer to byte buffers. + * + * @param value 2 byte fixed length integer + */ + def writeInt2(value: Int): ByteBuf = self.writeShortLE(value) + + /** + * Read 3 byte fixed length integer from byte buffers. + * + * @return 3 byte fixed length integer + */ + def readInt3: Int = self.readUnsignedMediumLE + + /** + * Write 3 byte fixed length integer to byte buffers. + * + * @param value 3 byte fixed length integer + */ + def writeInt3(value: Int): ByteBuf = self.writeMediumLE(value) + + /** + * Read 4 byte fixed length integer from byte buffers. + * + * @return 4 byte fixed length integer + */ + def readInt4: Int = self.readIntLE + + /** + * Write 4 byte fixed length integer to byte buffers. + * + * @param value 4 byte fixed length integer + */ + def writeInt4(value: Int): ByteBuf = self.writeIntLE(value) + + /** + * Read 6 byte fixed length integer from byte buffers. + * + * @return 6 byte fixed length integer + */ + def readInt6: Long = { + var result = 0L + var i = 0 + while (i < 6) { + result |= (0xff & self.readByte).toLong << (8 * i) + i = i + 1 + } + result + } + + /** + * Write 6 byte fixed length integer to byte buffers. + * + * @param value 6 byte fixed length integer + */ + def writeInt6(value: Long): ByteBuf = throw new UnsupportedOperationException + + /** + * Read 8 byte fixed length integer from byte buffers. + * + * @return 8 byte fixed length integer + */ + def readInt8: Long = self.readLongLE + + /** + * Write 8 byte fixed length integer to byte buffers. + * + * @param value 8 byte fixed length integer + */ + def writeInt8(value: Long): ByteBuf = self.writeLongLE(value) + + /** + * Read lenenc integer from byte buffers. + * + * @return lenenc integer + */ + def readIntLenenc: Long = { + val firstByte = readInt1 + if (firstByte < 0xfb) return firstByte + if (0xfb == firstByte) return 0 + if (0xfc == firstByte) return readInt2 + if (0xfd == firstByte) return readInt3 + self.readLongLE + } + + /** + * Write lenenc integer to byte buffers. + * + * @param value lenenc integer + */ + def writeIntLenenc(value: Long): ByteBuf = { + if (value < 0xfb) { + self.writeByte(value.toInt) + } else if (value < (1 << 16)) { + self.writeByte(0xfc) + self.writeShortLE(value.toInt) + } else if (value < (1 << 24)) { + self.writeByte(0xfd) + self.writeMediumLE(value.toInt) + } else { + self.writeByte(0xfe) + self.writeLongLE(value) + } + } + + /** + * Read fixed length long from byte buffers. + * + * @param length length read from byte buffers + * @return fixed length long + */ + def readLong(length: Int): Long = { + var result = 0L + var i = 0 + while (i < length) { + result = result << 8 | readInt1 + i = i + 1 + } + result + } + + /** + * Read lenenc string from byte buffers. + * + * @return lenenc string + */ + def readStringLenenc: String = { + val length = readIntLenenc.toInt + val result = new Array[Byte](length) + self.readBytes(result) + new String(result, charset) + } + + /** + * Read lenenc string from byte buffers for bytes. + * + * @return lenenc bytes + */ + def readStringLenencByBytes: Array[Byte] = { + val length = readIntLenenc.toInt + val result = new Array[Byte](length) + self.readBytes(result) + result + } + + /** + * Write lenenc string to byte buffers. + * + * @param value fixed length string + */ + def writeStringLenenc(value: String): ByteBuf = { + val bytes = value.getBytes(charset) + writeIntLenenc(bytes.length) + self.writeBytes(bytes) + } + + /** + * Write lenenc bytes to byte buffers. + * + * @param value fixed length bytes + */ + def writeBytesLenenc(value: Array[Byte]): ByteBuf = { + if (0 == value.length) { + self.writeByte(0) + return self + } + writeIntLenenc(value.length) + self.writeBytes(value) + } + + /** + * Read fixed length string from byte buffers. + * + * @param length length of fixed string + * @return fixed length string + */ + def readStringFix(length: Int): String = new String(readStringFixByBytes(length), charset) + + /** + * Read fixed length string from byte buffers and return bytes. + * + * @param length length of fixed string + * @return fixed length bytes + */ + def readStringFixByBytes(length: Int): Array[Byte] = { + val result = new Array[Byte](length) + self.readBytes(result) + result + } + + /** + * Write variable length string to byte buffers. + * + * @param value fixed length string + */ + def writeStringFix(value: String): ByteBuf = self.writeBytes(value.getBytes(charset)) + + /** + * Write variable length bytes to byte buffers. + * + * @param value fixed length bytes + */ + def writeBytes(value: Array[Byte]): ByteBuf = self.writeBytes(value) + + /** + * Read null terminated string from byte buffers. + * + * @return null terminated string + */ + def readStringNul: String = new String(readStringNulByBytes, charset) + + /** + * Read null terminated string from byte buffers and return bytes. + * + * @return null terminated bytes + */ + def readStringNulByBytes: Array[Byte] = { + val result = new Array[Byte](self.bytesBefore(0.toByte)) + self.readBytes(result) + self.skipBytes(1) + result + } + + /** + * Write null terminated string to byte buffers. + * + * @param value null terminated string + */ + def writeStringNul(value: String): ByteBuf = { + self.writeBytes(value.getBytes(charset)) + self.writeByte(0) + } + + /** + * Read rest of packet string from byte buffers and return bytes. + * + * @return rest of packet string bytes + */ + def readStringEOFByBytes: Array[Byte] = { + val result = new Array[Byte](self.readableBytes) + self.readBytes(result) + result + } + + /** + * Read rest of packet string from byte buffers. + * + * @return rest of packet string + */ + def readStringEOF: String = { + val result = new Array[Byte](self.readableBytes) + self.readBytes(result) + new String(result, charset) + } + + /** + * Write rest of packet string to byte buffers. + * + * @param value rest of packet string + */ + def writeStringEOF(value: String): ByteBuf = self.writeBytes(value.getBytes(charset)) + + /** + * Skip reserved from byte buffers. + * + * @param length length of reserved + */ + def skipReserved(length: Int): ByteBuf = self.skipBytes(length) + + /** + * Write null for reserved to byte buffers. + * + * @param length length of reserved + */ + def writeReserved(length: Int): ByteBuf = self.writeZero(length) + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPackets.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPackets.scala new file mode 100644 index 000000000..0e731bde3 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPackets.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.authentication + +import io.netty.buffer.ByteBuf + +import org.apache.kyuubi.server.mysql._ +import org.apache.kyuubi.server.mysql.MySQLRichByteBuf.Implicit +import org.apache.kyuubi.server.mysql.constant._ + +case class MySQLAuthSwitchRequestPacket( + sequenceId: Int, + authPluginName: String, + authPluginData: MySQLNativePassword.PluginData +) extends MySQLPacket with SupportsEncode { + + def header: Int = 0xfe + + override def encode(payload: ByteBuf): Unit = { + payload.writeInt1(header) + payload.writeStringNul(authPluginName) + payload.writeStringNul(new String(authPluginData.full)) + } +} + +object MySQLAuthSwitchResponsePacket extends SupportsDecode[MySQLAuthSwitchResponsePacket] { + + override def decode(payload: ByteBuf): MySQLAuthSwitchResponsePacket = { + val _sequenceId = payload.readInt1 + val _authPluginResponse = payload.readStringEOFByBytes + MySQLAuthSwitchResponsePacket(_sequenceId, _authPluginResponse) + } +} + +case class MySQLAuthSwitchResponsePacket( + sequenceId: Int, + authPluginResponse: Array[Byte] +) extends MySQLPacket + +case class MySQLHandshakePacket( + connectionId: Int, + authPluginData: MySQLNativePassword.PluginData +) extends MySQLPacket with SupportsEncode { + + def protocolVersion: Int = MySQLServerDefines.PROTOCOL_VERSION + + def serverVersion: String = MySQLServerDefines.MYSQL_KYUUBI_SERVER_VERSION + + def statusFlag: MySQLStatusFlag = MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT + + def charset: Int = MySQLServerDefines.CHARSET + + def capabilityFlagsLower: Int = MySQLCapabilityFlag.handshakeValueLower + + def capabilityFlagsUpper: Int = MySQLCapabilityFlag.handshakeValueUpper + + def authPluginName: String = MySQLAuthenticationMethod.NATIVE_PASSWORD.method + + override def sequenceId: Int = 0 + + override def encode(payload: ByteBuf): Unit = { + payload.writeInt1(protocolVersion) + payload.writeStringNul(serverVersion) + payload.writeInt4(connectionId) + payload.writeStringNul(new String(authPluginData.part1)) + payload.writeInt2(capabilityFlagsLower) + payload.writeInt1(charset) + payload.writeInt2(statusFlag.value) + payload.writeInt2(capabilityFlagsUpper) + payload.writeInt1(if (isClientPluginAuth) authPluginData.full.length + 1 else 0) + payload.writeReserved(10) + if (isClientSecureConnection) payload.writeStringNul(new String(authPluginData.part2)) + if (isClientPluginAuth) payload.writeStringNul(authPluginName) + } + + private def isClientSecureConnection = + (capabilityFlagsLower & MySQLCapabilityFlag.CLIENT_SECURE_CONNECTION.value & 0x00000ffff) != 0 + + private def isClientPluginAuth = + (capabilityFlagsUpper & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.value >> 16) != 0 +} + +object MySQLHandshakeResponse41Packet extends SupportsDecode[MySQLHandshakeResponse41Packet] { + override def decode(payload: ByteBuf): MySQLHandshakeResponse41Packet = { + val sequenceId = payload.readInt1 + val capabilityFlags = payload.readInt4 + val maxPacketSize = payload.readInt4 + val characterSet = payload.readInt1 + payload.skipReserved(23) + val username = payload.readStringNul + val authResponse = readAuthResponse(payload, capabilityFlags) + val database = readDatabase(payload, capabilityFlags) + val authPluginName = readAuthPluginName(payload, capabilityFlags) + MySQLHandshakeResponse41Packet( + sequenceId, + capabilityFlags, + maxPacketSize, + characterSet, + username, + authResponse, + database, + authPluginName) + } + + private def readAuthResponse(payload: ByteBuf, capabilityFlags: Int): Array[Byte] = { + if (0 != (capabilityFlags & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA.value)) { + return payload.readStringLenencByBytes + } + if (0 != (capabilityFlags & MySQLCapabilityFlag.CLIENT_SECURE_CONNECTION.value)) { + val length = payload.readInt1 + return payload.readStringFixByBytes(length) + } + payload.readStringNulByBytes + } + + private def readDatabase(payload: ByteBuf, capabilityFlags: Int): String = + if (0 != (capabilityFlags & MySQLCapabilityFlag.CLIENT_CONNECT_WITH_DB.value)) { + payload.readStringNul + } else null + + private def readAuthPluginName(payload: ByteBuf, capabilityFlags: Int): String = + if (0 != (capabilityFlags & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.value)) payload.readStringNul + else null +} + +case class MySQLHandshakeResponse41Packet( + sequenceId: Int, + capabilityFlags: Int, + maxPacketSize: Int, + characterSet: Int, + username: String, + authResponse: Array[Byte], + database: String, + authPluginName: String +) extends MySQLPacket diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthentication.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthentication.scala new file mode 100644 index 000000000..25c9f3f7c --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthentication.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.authentication + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext + +import org.apache.kyuubi.server.mysql._ +import org.apache.kyuubi.server.mysql.authentication.MySQLAuthentication._ +import org.apache.kyuubi.server.mysql.constant._ + +object MySQLAuthentication { + + private val seed: Array[Byte] = Array( + // format: off + 'a', 'b', 'e', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9') + // format: on + + def randomBytes(length: Int): Array[Byte] = { + val result = new Array[Byte](length) + var i = 0 + while (i < length) { + result(i) = seed(Random.nextInt(seed.length)) + i = i + 1 + } + result + } + + final val connIdCounter = new AtomicInteger +} + +object MySQLConnectionPhase extends Enumeration { + type MySQLConnectionPhase = Value + + val INITIAL_HANDSHAKE, AUTH_PHASE_FAST_PATH, AUTHENTICATION_METHOD_MISMATCH = Value +} + +sealed abstract class MySQLAuthenticationMethod(val method: String) + +object MySQLAuthenticationMethod { + + object OLD_PASSWORD extends MySQLAuthenticationMethod("mysql_old_password") + + // Currently, it's the ONLY supported authentication method + object NATIVE_PASSWORD extends MySQLAuthenticationMethod("mysql_native_password") + + object CLEAR_TEXT extends MySQLAuthenticationMethod("mysql_clear_password") + + object WINDOWS_NATIVE extends MySQLAuthenticationMethod("authentication_windows_client") + + object SHA256 extends MySQLAuthenticationMethod("sha256_password") +} + +case class AuthenticationResult( + user: String, + ip: String, + database: String, + finished: Boolean +) + +object AuthenticationResult { + def finished(username: String, ip: String, database: String): AuthenticationResult = + new AuthenticationResult(username, ip, database, true) + + def continued: AuthenticationResult = + new AuthenticationResult(null, null, null, false) + + def continued(username: String, ip: String, database: String): AuthenticationResult = + new AuthenticationResult(username, ip, database, false) +} + +class MySQLAuthenticationEngine { + private final val authenticator = new MySQLNativePassword + private final val currentSeqId = new AtomicInteger + private var connectionPhase = MySQLConnectionPhase.INITIAL_HANDSHAKE + private var authResponse: Array[Byte] = _ + private var authResult: AuthenticationResult = _ + + def handshake(ctx: ChannelHandlerContext): Int = { + val connectionId = connIdCounter.getAndIncrement + connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH + ctx.writeAndFlush(MySQLHandshakePacket(connectionId, authenticator.pluginData)) + connectionId + } + + def authenticate(ctx: ChannelHandlerContext, buf: ByteBuf): AuthenticationResult = { + connectionPhase match { + case MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH => + authenticationMethodMismatch(buf) + case MySQLConnectionPhase.AUTH_PHASE_FAST_PATH => + authResult = authPhaseFastPath(ctx, buf) + if (!authResult.finished) return authResult + case _ => // never happen + } + val seqId = currentSeqId.incrementAndGet + val responsePacket = authenticator + .login(authResult.user, remoteAddress(ctx), authResponse, authResult.database) + .map(createErrorPacket(ctx, _, seqId)) + .getOrElse(MySQLOKPacket(seqId)) + ctx.writeAndFlush(responsePacket) + + AuthenticationResult.finished(authResult.user, remoteAddress(ctx), authResult.database) + } + + private def authPhaseFastPath(ctx: ChannelHandlerContext, buf: ByteBuf): AuthenticationResult = { + val packet = MySQLHandshakeResponse41Packet.decode(buf) + authResponse = packet.authResponse + currentSeqId.set(packet.sequenceId) + // always switch to mysql_native_password since Kyuubi Server only support this method + if (isClientPluginAuth(packet) + && packet.authPluginName != MySQLAuthenticationMethod.NATIVE_PASSWORD.method) { + connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH + ctx.writeAndFlush(MySQLAuthSwitchRequestPacket( + currentSeqId.incrementAndGet, + MySQLAuthenticationMethod.NATIVE_PASSWORD.method, + authenticator.pluginData)) + return AuthenticationResult.continued(packet.username, remoteAddress(ctx), packet.database) + } + AuthenticationResult.finished(packet.username, remoteAddress(ctx), packet.database) + } + + private def isClientPluginAuth(packet: MySQLHandshakeResponse41Packet) = + (packet.capabilityFlags & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.value) != 0 + + private def authenticationMethodMismatch(buf: ByteBuf): Unit = { + val packet = MySQLAuthSwitchResponsePacket.decode(buf) + currentSeqId.set(packet.sequenceId) + authResponse = packet.authPluginResponse + } + + private def createErrorPacket( + ctx: ChannelHandlerContext, + errorCode: MySQLErrorCode, + seqId: Int + ): MySQLErrPacket = errorCode match { + case MySQLErrorCode.ER_DBACCESS_DENIED_ERROR => MySQLErrPacket( + seqId, + MySQLErrorCode.ER_DBACCESS_DENIED_ERROR, + authResult.user, + remoteAddress(ctx), + authResult.database) + case _ => MySQLErrPacket( + seqId, + MySQLErrorCode.ER_ACCESS_DENIED_ERROR, + authResult.user, + remoteAddress(ctx), + errorMessage) + } + + private def errorMessage = if (authResponse.nonEmpty) "YES" else "NO" + + private def remoteAddress(ctx: ChannelHandlerContext): String = { + ctx.channel.remoteAddress match { + case address: InetSocketAddress => address.getAddress.getHostAddress + case other => other.toString + } + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLNativePassword.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLNativePassword.scala new file mode 100644 index 000000000..c23818632 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/authentication/MySQLNativePassword.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.authentication + +import java.util + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.kyuubi.server.mysql.authentication.MySQLAuthentication.randomBytes +import org.apache.kyuubi.server.mysql.authentication.MySQLNativePassword.PluginData +import org.apache.kyuubi.server.mysql.constant.MySQLErrorCode + +object MySQLNativePassword { + case class PluginData( + part1: Array[Byte] = randomBytes(8), + part2: Array[Byte] = randomBytes(12) + ) { + lazy val full: Array[Byte] = Array.concat(part1, part2) + } +} + +class MySQLNativePassword { + private final val _pluginData = new PluginData + + def pluginData: PluginData = _pluginData + + def login( + user: String, + host: String, + authResp: Array[Byte], + database: String + ): Option[MySQLErrorCode] = { + if (isPasswordRight("kyuubi", authResp)) { + // if (true) { + None + } else { + Some(MySQLErrorCode.ER_ACCESS_DENIED_ERROR) + } + } + + private[authentication] def isPasswordRight(password: String, authentication: Array[Byte]) = + util.Arrays.equals(getAuthCipherBytes(password), authentication) + + private def getAuthCipherBytes(password: String): Array[Byte] = { + val salt = pluginData.full + val passwordSha1 = DigestUtils.sha1(password) + val passwordSha1Sha1 = DigestUtils.sha1(passwordSha1) + val secret = new Array[Byte](salt.length + passwordSha1Sha1.length) + System.arraycopy(salt, 0, secret, 0, salt.length) + System.arraycopy(passwordSha1Sha1, 0, secret, salt.length, passwordSha1Sha1.length) + val secretSha1 = DigestUtils.sha1(secret) + xor(passwordSha1, secretSha1) + } + + private def xor(input: Array[Byte], secret: Array[Byte]): Array[Byte] = + (input zip secret).map { case (b1, b2) => (b1 ^ b2).toByte } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCapabilityFlag.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCapabilityFlag.scala new file mode 100644 index 000000000..b3af5ea50 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCapabilityFlag.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +sealed abstract class MySQLCapabilityFlag(val value: Int) + +// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html +object MySQLCapabilityFlag { + + object CLIENT_LONG_PASSWORD extends MySQLCapabilityFlag(0x00000001) + + object CLIENT_FOUND_ROWS extends MySQLCapabilityFlag(0x00000002) + + object CLIENT_LONG_FLAG extends MySQLCapabilityFlag(0x00000004) + + object CLIENT_CONNECT_WITH_DB extends MySQLCapabilityFlag(0x00000008) + + object CLIENT_NO_SCHEMA extends MySQLCapabilityFlag(0x00000010) + + object CLIENT_COMPRESS extends MySQLCapabilityFlag(0x00000020) + + object CLIENT_ODBC extends MySQLCapabilityFlag(0x00000040) + + object CLIENT_LOCAL_FILES extends MySQLCapabilityFlag(0x00000080) + + object CLIENT_IGNORE_SPACE extends MySQLCapabilityFlag(0x00000100) + + object CLIENT_PROTOCOL_41 extends MySQLCapabilityFlag(0x00000200) + + object CLIENT_INTERACTIVE extends MySQLCapabilityFlag(0x00000400) + + object CLIENT_SSL extends MySQLCapabilityFlag(0x00000800) + + object CLIENT_IGNORE_SIGPIPE extends MySQLCapabilityFlag(0x00001000) + + object CLIENT_TRANSACTIONS extends MySQLCapabilityFlag(0x00002000) + + object CLIENT_RESERVED extends MySQLCapabilityFlag(0x00004000) + + object CLIENT_SECURE_CONNECTION extends MySQLCapabilityFlag(0x00008000) + + object CLIENT_MULTI_STATEMENTS extends MySQLCapabilityFlag(0x00010000) + + object CLIENT_MULTI_RESULTS extends MySQLCapabilityFlag(0x00020000) + + object CLIENT_PS_MULTI_RESULTS extends MySQLCapabilityFlag(0x00040000) + + object CLIENT_PLUGIN_AUTH extends MySQLCapabilityFlag(0x00080000) + + object CLIENT_CONNECT_ATTRS extends MySQLCapabilityFlag(0x00100000) + + object CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA extends MySQLCapabilityFlag(0x00200000) + + object CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS extends MySQLCapabilityFlag(0x00400000) + + object CLIENT_SESSION_TRACK extends MySQLCapabilityFlag(0x00800000) + + // since MySQL 5.7.15, disable it for compatible + object CLIENT_DEPRECATE_EOF extends MySQLCapabilityFlag(0x01000000) + + val handshakeValue: Int = calculateValues( + CLIENT_LONG_PASSWORD, + CLIENT_FOUND_ROWS, + CLIENT_LONG_FLAG, + CLIENT_CONNECT_WITH_DB, + CLIENT_ODBC, + CLIENT_IGNORE_SPACE, + CLIENT_PROTOCOL_41, + CLIENT_INTERACTIVE, + CLIENT_IGNORE_SIGPIPE, + CLIENT_TRANSACTIONS, + CLIENT_SECURE_CONNECTION, + CLIENT_PLUGIN_AUTH) + + val handshakeValueLower: Int = handshakeValue & 0x0000ffff + + val handshakeValueUpper: Int = handshakeValue >>> 16 + + private def calculateValues(capabilities: MySQLCapabilityFlag*): Int = + capabilities.foldLeft(0) { case (acc, item) => acc | item.value } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCommandPacketType.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCommandPacketType.scala new file mode 100644 index 000000000..d0466bb51 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLCommandPacketType.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +sealed abstract class MySQLCommandPacketType(val value: Int) + +object MySQLCommandPacketType { + + // https://dev.mysql.com/doc/internals/en/com-sleep.html + object COM_SLEEP extends MySQLCommandPacketType(0x00) + + // https://dev.mysql.com/doc/internals/en/com-quit.html + object COM_QUIT extends MySQLCommandPacketType(0x01) + + // https://dev.mysql.com/doc/internals/en/com-init-db.html + object COM_INIT_DB extends MySQLCommandPacketType(0x02) + + // https://dev.mysql.com/doc/internals/en/com-query.html + object COM_QUERY extends MySQLCommandPacketType(0x03) + + // https://dev.mysql.com/doc/internals/en/com-field-list.html + object COM_FIELD_LIST extends MySQLCommandPacketType(0x04) + + // https://dev.mysql.com/doc/internals/en/com-create-db.html + object COM_CREATE_DB extends MySQLCommandPacketType(0x05) + + // https://dev.mysql.com/doc/internals/en/com-drop-db.html + object COM_DROP_DB extends MySQLCommandPacketType(0x06) + + // https://dev.mysql.com/doc/internals/en/com-refresh.html + object COM_REFRESH extends MySQLCommandPacketType(0x07) + + // https://dev.mysql.com/doc/internals/en/com-shutdown.html + object COM_SHUTDOWN extends MySQLCommandPacketType(0x08) + + // https://dev.mysql.com/doc/internals/en/com-statistics.html + object COM_STATISTICS extends MySQLCommandPacketType(0x09) + + // https://dev.mysql.com/doc/internals/en/com-process-info.html + object COM_PROCESS_INFO extends MySQLCommandPacketType(0x0a) + + // https://dev.mysql.com/doc/internals/en/com-connect.html + object COM_CONNECT extends MySQLCommandPacketType(0x0b) + + // https://dev.mysql.com/doc/internals/en/com-process-kill.html + object COM_PROCESS_KILL extends MySQLCommandPacketType(0x0c) + + // https://dev.mysql.com/doc/internals/en/com-debug.html + object COM_DEBUG extends MySQLCommandPacketType(0x0d) + + // https://dev.mysql.com/doc/internals/en/com-ping.html + object COM_PING extends MySQLCommandPacketType(0x0e) + + // https://dev.mysql.com/doc/internals/en/com-time.html + object COM_TIME extends MySQLCommandPacketType(0x0f) + + // https://dev.mysql.com/doc/internals/en/com-delayed-insert.html + object COM_DELAYED_INSERT extends MySQLCommandPacketType(0x10) + + // https://dev.mysql.com/doc/internals/en/com-change-user.html + object COM_CHANGE_USER extends MySQLCommandPacketType(0x11) + + // https://dev.mysql.com/doc/internals/en/com-binlog-dump.html + object COM_BINLOG_DUMP extends MySQLCommandPacketType(0x12) + + // https://dev.mysql.com/doc/internals/en/com-table-dump.html + object COM_TABLE_DUMP extends MySQLCommandPacketType(0x13) + + // https://dev.mysql.com/doc/internals/en/com-connect-out.html + object COM_CONNECT_OUT extends MySQLCommandPacketType(0x14) + + // https://dev.mysql.com/doc/internals/en/com-register-slave.html + object COM_REGISTER_SLAVE extends MySQLCommandPacketType(0x15) + + // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html + object COM_STMT_PREPARE extends MySQLCommandPacketType(0x16) + + // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + object COM_STMT_EXECUTE extends MySQLCommandPacketType(0x17) + + // https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html + object COM_STMT_SEND_LONG_DATA extends MySQLCommandPacketType(0x18) + + // https://dev.mysql.com/doc/internals/en/com-stmt-close.html + object COM_STMT_CLOSE extends MySQLCommandPacketType(0x19) + + // https://dev.mysql.com/doc/internals/en/com-stmt-reset.html + object COM_STMT_RESET extends MySQLCommandPacketType(0x1a) + + // https://dev.mysql.com/doc/internals/en/com-set-option.html + object COM_SET_OPTION extends MySQLCommandPacketType(0x1b) + + // https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html + object COM_STMT_FETCH extends MySQLCommandPacketType(0x1c) + + // https://dev.mysql.com/doc/internals/en/com-daemon.html + object COM_DAEMON extends MySQLCommandPacketType(0x1d) + + // https://dev.mysql.com/doc/internals/en/com-binlog-dump-gtid.html + object COM_BINLOG_DUMP_GTID extends MySQLCommandPacketType(0x1e) + + // https://dev.mysql.com/doc/internals/en/com-reset-connection.html + object COM_RESET_CONNECTION extends MySQLCommandPacketType(0x1f) + + def valueOf(value: Int): MySQLCommandPacketType = value match { + case COM_SLEEP.value => COM_SLEEP + case COM_QUIT.value => COM_QUIT + case COM_INIT_DB.value => COM_INIT_DB + case COM_QUERY.value => COM_QUERY + case COM_FIELD_LIST.value => COM_FIELD_LIST + case COM_CREATE_DB.value => COM_CREATE_DB + case COM_DROP_DB.value => COM_DROP_DB + case COM_REFRESH.value => COM_REFRESH + case COM_SHUTDOWN.value => COM_SHUTDOWN + case COM_STATISTICS.value => COM_STATISTICS + case COM_PROCESS_INFO.value => COM_PROCESS_INFO + case COM_CONNECT.value => COM_CONNECT + case COM_PROCESS_KILL.value => COM_PROCESS_KILL + case COM_DEBUG.value => COM_DEBUG + case COM_PING.value => COM_PING + case COM_TIME.value => COM_TIME + case COM_DELAYED_INSERT.value => COM_DELAYED_INSERT + case COM_CHANGE_USER.value => COM_CHANGE_USER + case COM_BINLOG_DUMP.value => COM_BINLOG_DUMP + case COM_TABLE_DUMP.value => COM_TABLE_DUMP + case COM_CONNECT_OUT.value => COM_CONNECT_OUT + case COM_REGISTER_SLAVE.value => COM_REGISTER_SLAVE + case COM_STMT_PREPARE.value => COM_STMT_PREPARE + case COM_STMT_EXECUTE.value => COM_STMT_EXECUTE + case COM_STMT_SEND_LONG_DATA.value => COM_STMT_SEND_LONG_DATA + case COM_STMT_CLOSE.value => COM_STMT_CLOSE + case COM_STMT_RESET.value => COM_STMT_RESET + case COM_SET_OPTION.value => COM_SET_OPTION + case COM_STMT_FETCH.value => COM_STMT_FETCH + case COM_DAEMON.value => COM_DAEMON + case COM_BINLOG_DUMP_GTID.value => COM_BINLOG_DUMP_GTID + case COM_RESET_CONNECTION.value => COM_RESET_CONNECTION + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLDataType.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLDataType.scala new file mode 100644 index 000000000..7a76b47f4 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLDataType.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +import java.sql.Types + +import org.apache.hive.service.rpc.thrift.TTypeId + +sealed abstract class MySQLDataType(val value: Int) + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type +object MySQLDataType { + object DECIMAL extends MySQLDataType(0x00) + + object TINY extends MySQLDataType(0x01) + + object SHORT extends MySQLDataType(0x02) + + object LONG extends MySQLDataType(0x03) + + object FLOAT extends MySQLDataType(0x04) + + object DOUBLE extends MySQLDataType(0x05) + + object NULL extends MySQLDataType(0x06) + + object TIMESTAMP extends MySQLDataType(0x07) + + object LONGLONG extends MySQLDataType(0x08) + + object INT24 extends MySQLDataType(0x09) + + object DATE extends MySQLDataType(0x0a) + + object TIME extends MySQLDataType(0x0b) + + object DATETIME extends MySQLDataType(0x0c) + + object YEAR extends MySQLDataType(0x0d) + + // Internal to MySQL Server + object NEWDATE extends MySQLDataType(0x0e) + + object VARCHAR extends MySQLDataType(0x0f) + + object BIT extends MySQLDataType(0x10) + + // Internal to MySQL Server + object TIMESTAMP2 extends MySQLDataType(0x11) + + // Internal to MySQL Server + object DATETIME2 extends MySQLDataType(0x12) + + // Internal to MySQL Server + object TIME2 extends MySQLDataType(0x13) + + // Do not describe in document, but actual exist. + // https://github.com/apache/shardingsphere/issues/4795 + object JSON extends MySQLDataType(0xf5) + + object NEWDECIMAL extends MySQLDataType(0xf6) + + object ENUM extends MySQLDataType(0xf7) + + object SET extends MySQLDataType(0xf8) + + object TINY_BLOB extends MySQLDataType(0xf9) + + object MEDIUM_BLOB extends MySQLDataType(0xfa) + + object LONG_BLOB extends MySQLDataType(0xfb) + + object BLOB extends MySQLDataType(0xfc) + + object VAR_STRING extends MySQLDataType(0xfd) + + object STRING extends MySQLDataType(0xfe) + + object GEOMETRY extends MySQLDataType(0xff) + + def valueOf(value: Int): MySQLDataType = value match { + case 0x00 => DECIMAL + case 0x01 => TINY + case 0x02 => SHORT + case 0x03 => LONG + case 0x04 => FLOAT + case 0x05 => DOUBLE + case 0x06 => NULL + case 0x07 => TIMESTAMP + case 0x08 => LONGLONG + case 0x09 => INT24 + case 0x0a => DATE + case 0x0b => TIME + case 0x0c => DATETIME + case 0x0d => YEAR + case 0x0e => NEWDATE + case 0x0f => VARCHAR + case 0x10 => BIT + case 0x11 => TIMESTAMP2 + case 0x12 => DATETIME2 + case 0x13 => TIME2 + case 0xf5 => JSON + case 0xf6 => NEWDECIMAL + case 0xf7 => ENUM + case 0xf8 => SET + case 0xf9 => TINY_BLOB + case 0xfa => MEDIUM_BLOB + case 0xfb => LONG_BLOB + case 0xfc => BLOB + case 0xfd => VAR_STRING + case 0xfe => STRING + case 0xff => GEOMETRY + case other => throw new IllegalArgumentException( + s"Illegal value $other of MySQLDataType") + } + + def valueOfJdbcType(jdbcValue: Int): MySQLDataType = jdbcValue match { + case Types.BIT => BIT + case Types.TINYINT => TINY + case Types.SMALLINT => SHORT + case Types.INTEGER => LONG + case Types.BIGINT => LONGLONG + case Types.FLOAT => FLOAT + case Types.REAL => FLOAT + case Types.DOUBLE => DOUBLE + case Types.NUMERIC => NEWDECIMAL + case Types.DECIMAL => NEWDECIMAL + case Types.CHAR => STRING + case Types.VARCHAR => VAR_STRING + case Types.LONGVARCHAR => VAR_STRING + case Types.DATE => DATE + case Types.TIME => TIME + case Types.TIMESTAMP => TIMESTAMP + case Types.BINARY => STRING + case Types.VARBINARY => VAR_STRING + case Types.LONGVARBINARY => VAR_STRING + case Types.NULL => NULL + case Types.BLOB => BLOB + case other => throw new IllegalArgumentException( + s"Illegal JDBC type value $other of MySQLDataType") + } + + def valueOfThriftType(tType: TTypeId): MySQLDataType = tType match { + case TTypeId.BOOLEAN_TYPE => TINY + case TTypeId.TINYINT_TYPE => TINY + case TTypeId.SMALLINT_TYPE => SHORT + case TTypeId.INT_TYPE => LONG + case TTypeId.BIGINT_TYPE => LONGLONG + case TTypeId.FLOAT_TYPE => FLOAT + case TTypeId.DOUBLE_TYPE => DOUBLE + case TTypeId.STRING_TYPE => VAR_STRING + case TTypeId.TIMESTAMP_TYPE => TIMESTAMP + case TTypeId.BINARY_TYPE => STRING + case TTypeId.ARRAY_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.MAP_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.STRUCT_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.UNION_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.USER_DEFINED_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.DECIMAL_TYPE => NEWDECIMAL + case TTypeId.NULL_TYPE => NULL + case TTypeId.DATE_TYPE => DATE + case TTypeId.VARCHAR_TYPE => VAR_STRING + case TTypeId.CHAR_TYPE => STRING + case TTypeId.INTERVAL_YEAR_MONTH_TYPE => VAR_STRING // not exactly match, fallback + case TTypeId.INTERVAL_DAY_TIME_TYPE => VAR_STRING // not exactly match, fallback + } +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLErrorCode.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLErrorCode.scala new file mode 100644 index 000000000..a6c35b1be --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLErrorCode.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +import org.apache.kyuubi.KyuubiSQLException + +case class MySQLErrorCode(errorCode: Int, sqlState: String, errorMessage: String) { + override def toString: String = s"ERROR $errorCode ($sqlState): $errorMessage" + + def toKyuubiSQLException: KyuubiSQLException = { + new KyuubiSQLException(errorMessage, sqlState, errorCode, null) + } +} + +object MySQLErrorCode { + + object TOO_MANY_CONNECTIONS_EXCEPTION extends MySQLErrorCode( + 1040, + "08004", + "Too many connections") + + object RUNTIME_EXCEPTION extends MySQLErrorCode( + 1997, + "C1997", + "Runtime exception: %s") + + object UNSUPPORTED_COMMAND extends MySQLErrorCode( + 1998, + "C1998", + "Unsupported command: %s") + + object UNKNOWN_EXCEPTION extends MySQLErrorCode( + 1999, + "C1999", + "Unknown exception: %s") + + object ER_DBACCESS_DENIED_ERROR extends MySQLErrorCode( + 1044, + "42000", + "Access denied for user '%s'@'%s' to database '%s'") + + object ER_ACCESS_DENIED_ERROR extends MySQLErrorCode( + 1045, + "28000", + "Access denied for user '%s'@'%s' (using password: %s)") + + object ER_NO_DB_ERROR extends MySQLErrorCode( + 1046, + "3D000", + "No database selected") + + object ER_BAD_DB_ERROR extends MySQLErrorCode( + 1049, + "42000", + "Unknown database '%s'") + + object ER_INTERNAL_ERROR extends MySQLErrorCode( + 1815, + "HY000", + "Internal error: %s") + + object ER_UNSUPPORTED_PS extends MySQLErrorCode( + 1295, + "HY000", + "This command is not supported in the prepared statement protocol yet") + + object ER_DB_CREATE_EXISTS_ERROR extends MySQLErrorCode( + 1007, + "HY000", + "Can't create database '%s'; database exists") + + object ER_DB_DROP_EXISTS_ERROR extends MySQLErrorCode( + 1008, + "HY000", + "Can't drop database '%s'; database doesn't exist") + + object ER_TABLE_EXISTS_ERROR extends MySQLErrorCode( + 1050, + "42S01", + "Table '%s' already exists") + + object ER_NO_SUCH_TABLE extends MySQLErrorCode( + 1146, + "42S02", + "Table '%s' doesn't exist") + + object ER_NOT_SUPPORTED_YET extends MySQLErrorCode( + 1235, + "42000", + "This version of Kyuubi-Server doesn't yet support this SQL. '%s'") +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLFieldDetailFlag.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLFieldDetailFlag.scala new file mode 100644 index 000000000..c1694ae28 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLFieldDetailFlag.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +sealed abstract class MySQLFieldDetailFlag(val value: Int) + +// https://mariadb.com/kb/en/library/resultset/#field-detail-flag +object MySQLFieldDetailFlag { + + object NOT_NULL extends MySQLFieldDetailFlag(0x00000001) + + object PRIMARY_KEY extends MySQLFieldDetailFlag(0x00000002) + + object UNIQUE_KEY extends MySQLFieldDetailFlag(0x00000004) + + object MULTIPLE_KEY extends MySQLFieldDetailFlag(0x00000008) + + object BLOB extends MySQLFieldDetailFlag(0x00000010) + + object UNSIGNED extends MySQLFieldDetailFlag(0x00000020) + + object ZEROFILL_FLAG extends MySQLFieldDetailFlag(0x00000040) + + object BINARY_COLLATION extends MySQLFieldDetailFlag(0x00000080) + + object ENUM extends MySQLFieldDetailFlag(0x00000100) + + object AUTO_INCREMENT extends MySQLFieldDetailFlag(0x00000200) + + object TIMESTAMP extends MySQLFieldDetailFlag(0x00000400) + + object SET extends MySQLFieldDetailFlag(0x00000800) + + object NO_DEFAULT_VALUE_FLAG extends MySQLFieldDetailFlag(0x00001000) + + object ON_UPDATE_NOW_FLAG extends MySQLFieldDetailFlag(0x00002000) + + object NUM_FLAG extends MySQLFieldDetailFlag(0x00008000) +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLServerDefines.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLServerDefines.scala new file mode 100644 index 000000000..6c57254e2 --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLServerDefines.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +import org.apache.kyuubi._ + +object MySQLServerDefines { + val PROTOCOL_VERSION = 0x0a + val CHARSET = 0x2d // utf8mb4_general_ci + val MYSQL_VERSION = "5.7.22" + val MYSQL_KYUUBI_SERVER_VERSION = s"$MYSQL_VERSION-Kyuubi-Server $KYUUBI_VERSION" + val KYUUBI_SERVER_DESCRIPTION = s"Apache Kyuubi (Incubating) v$KYUUBI_VERSION revision $REVISION" +} diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLStatusFlag.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLStatusFlag.scala new file mode 100644 index 000000000..b3eb7c62b --- /dev/null +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/mysql/constant/MySQLStatusFlag.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.constant + +sealed abstract class MySQLStatusFlag(val value: Int) + +// https://dev.mysql.com/doc/internals/en/status-flags.html#packet-Protocol::StatusFlags +object MySQLStatusFlag { + + object SERVER_STATUS_IN_TRANS extends MySQLStatusFlag(0x0001) + + object SERVER_STATUS_AUTOCOMMIT extends MySQLStatusFlag(0x0002) + + object SERVER_MORE_RESULTS_EXISTS extends MySQLStatusFlag(0x0008) + + object SERVER_STATUS_NO_GOOD_INDEX_USED extends MySQLStatusFlag(0x0010) + + object SERVER_STATUS_NO_INDEX_USED extends MySQLStatusFlag(0x0020) + + object SERVER_STATUS_CURSOR_EXISTS extends MySQLStatusFlag(0x0040) + + object SERVER_STATUS_LAST_ROW_SENT extends MySQLStatusFlag(0x0080) + + object SERVER_STATUS_DB_DROPPED extends MySQLStatusFlag(0x0100) + + object SERVER_STATUS_NO_BACKSLASH_ESCAPES extends MySQLStatusFlag(0x0200) + + object SERVER_STATUS_METADATA_CHANGED extends MySQLStatusFlag(0x0400) + + object SERVER_QUERY_WAS_SLOW extends MySQLStatusFlag(0x0800) + + object SERVER_PS_OUT_PARAMS extends MySQLStatusFlag(0x1000) + + object SERVER_STATUS_IN_TRANS_READONLY extends MySQLStatusFlag(0x2000) + + object SERVER_SESSION_STATE_CHANGED extends MySQLStatusFlag(0x4000) +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCodecHelper.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCodecHelper.scala new file mode 100644 index 000000000..4b40be8ad --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCodecHelper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import io.netty.buffer.{ByteBuf, ByteBufUtil, Unpooled} + +import org.apache.kyuubi.KyuubiFunSuite + +trait MySQLCodecHelper extends KyuubiFunSuite { + + def decodeHex(hexDump: String): ByteBuf = { + val compact = hexDump.replaceAll("(?s)\\s", "") + val bytes = ByteBufUtil.decodeHexDump(compact) + Unpooled.copiedBuffer(bytes) + } + + def verifyDecode[T <: MySQLPacket]( + decoder: SupportsDecode[T], + payload: ByteBuf, + expected: T + )(assertion: (T, T) => Unit): Unit = { + val decoded = decoder.decode(payload) + assertion(decoded, expected) + } + + def verifyEncode(expected: ByteBuf, packet: SupportsEncode): Unit = { + val encoded = Unpooled.buffer() + packet.encode(encoded) + assert(encoded === expected) + } +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCommandPacketSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCommandPacketSuite.scala new file mode 100644 index 000000000..3cd6f2674 --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLCommandPacketSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import org.apache.kyuubi.KyuubiFunSuite + +class MySQLCommandPacketSuite extends KyuubiFunSuite with MySQLCodecHelper { + + test("decode MySQLComInitDbPacket") { + val payload = decodeHex("6b 79 75 75 62 69") + val expected = MySQLComInitDbPacket("kyuubi") + verifyDecode(MySQLComInitDbPacket, payload, expected) { (decoded, expected) => + assert(decoded === expected) + } + } + + test("decode MySQLComFieldListPacket") { + val payload = decodeHex("6b 79 75 75 62 69 00 2a") + val expected = MySQLComFieldListPacket("kyuubi", "*") + verifyDecode(MySQLComFieldListPacket, payload, expected) { (decoded, expected) => + assert(decoded === expected) + } + } + + test("decode MySQLComQueryPacket") { + val payload = decodeHex( + """73 65 6c 65 63 74 20 6b 79 75 75 62 69 5f 76 65 + |72 73 69 6f 6e 28 29 + |""".stripMargin) + val expected = MySQLComQueryPacket("select kyuubi_version()") + verifyDecode(MySQLComQueryPacket, payload, expected) { (decoded, expected) => + assert(decoded === expected) + } + } +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLDataPacketSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLDataPacketSuite.scala new file mode 100644 index 000000000..2c23b310f --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLDataPacketSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import org.apache.kyuubi.KyuubiFunSuite +import org.apache.kyuubi.server.mysql.constant.MySQLDataType + +class MySQLDataPacketSuite extends KyuubiFunSuite with MySQLCodecHelper { + + test("encode MySQLFieldCountPacket") { + val packet = MySQLFieldCountPacket(1, 1) + val expected = decodeHex("01") + verifyEncode(expected, packet) + } + + test("encode MySQLColumnDefinition41Packet") { + val packet = MySQLColumnDefinition41Packet(1, 0, "UDF()", 100, MySQLDataType.VAR_STRING, 0) + val expected = decodeHex( + """00 00 00 00 05 55 44 46 28 29 00 0c 2d 00 64 00 + |00 00 fd 00 00 00 00 00 + |""".stripMargin) + verifyEncode(expected, packet) + } + + test("encode MySQLTextResultSetRowPacket") { + val packet = MySQLTextResultSetRowPacket(2, Seq("1.4.0-SNAPSHOT")) + val expected = decodeHex("0e 31 2e 34 2e 30 2d 53 4e 41 50 53 48 4f 54") + verifyEncode(expected, packet) + } +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLGenericPacketSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLGenericPacketSuite.scala new file mode 100644 index 000000000..5a463f364 --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/MySQLGenericPacketSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql + +import org.apache.kyuubi.KyuubiFunSuite +import org.apache.kyuubi.server.mysql.constant.MySQLErrorCode + +class MySQLGenericPacketSuite extends KyuubiFunSuite with MySQLCodecHelper { + + test("encode MySQLOKPacket") { + val packet = MySQLOKPacket(1, 2, 3) + val expected = decodeHex("00 02 03 02 00 00 00") + verifyEncode(expected, packet) + } + + test("encode MySQLErrPacket") { + val packet = MySQLErrPacket(1, MySQLErrorCode.TOO_MANY_CONNECTIONS_EXCEPTION) + val expected = decodeHex( + """ff 10 04 23 30 38 30 30 34 54 6f 6f 20 6d 61 6e + |79 20 63 6f 6e 6e 65 63 74 69 6f 6e 73 + |""".stripMargin) + verifyEncode(expected, packet) + } + + test("encode MySQLEofPacket") { + val packet = MySQLEofPacket() + val expected = decodeHex("fe 00 00 02 00") + verifyEncode(expected, packet) + } +} diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPacketSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPacketSuite.scala new file mode 100644 index 000000000..7379eff03 --- /dev/null +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/mysql/authentication/MySQLAuthPacketSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.server.mysql.authentication + +import java.util + +import org.apache.kyuubi.KyuubiFunSuite +import org.apache.kyuubi.server.mysql.MySQLCodecHelper + +class MySQLAuthPacketSuite extends KyuubiFunSuite with MySQLCodecHelper { + + private val authPluginData = { + val part1 = decodeHex("77 37 34 35 45 51 55 65").array + val part2 = decodeHex("69 44 57 32 44 44 33 4e 6d 36 69 74").array + MySQLNativePassword.PluginData(part1, part2) + } + + test("encode MySQLHandshakePacket") { + val packet = MySQLHandshakePacket(2, authPluginData) + val expected = decodeHex( + """0a 35 2e 37 2e 32 32 2d 4b 79 75 75 62 69 2d 53 + |65 72 76 65 72 20 31 2e 34 2e 30 2d 53 4e 41 50 + |53 48 4f 54 00 02 00 00 00 77 37 34 35 45 51 55 + |65 00 4f b7 2d 02 00 08 00 15 00 00 00 00 00 00 + |00 00 00 00 69 44 57 32 44 44 33 4e 6d 36 69 74 + |00 6d 79 73 71 6c 5f 6e 61 74 69 76 65 5f 70 61 + |73 73 77 6f 72 64 00 + |""".stripMargin) + verifyEncode(expected, packet) + } + + test("decode MySQLHandshakeResponse41Packet") { + val payload = decodeHex( + """01 85 a6 ff 19 00 00 00 01 2d 00 00 00 00 00 00 + |00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + |00 63 68 65 6e 67 70 61 6e 00 14 15 c6 82 8e 53 + |67 20 3a 44 f3 d1 3e 62 f8 2d 20 38 3c 75 94 6d + |79 73 71 6c 5f 6e 61 74 69 76 65 5f 70 61 73 73 + |77 6f 72 64 00 + |""".stripMargin) + val expected = MySQLHandshakeResponse41Packet( + 1, + 0x19ffa685, + 16777216, + 0x2d, + "chengpan", + decodeHex( + """15 c6 82 8e 53 67 20 3a 44 f3 d1 3e 62 f8 2d 20 + |38 3c 75 94 + |""".stripMargin).array, + null, + "mysql_native_password") + verifyDecode(MySQLHandshakeResponse41Packet, payload, expected) { (decoded, expected) => + assert(decoded.sequenceId === expected.sequenceId) + assert(decoded.capabilityFlags === expected.capabilityFlags) + assert(decoded.maxPacketSize === expected.maxPacketSize) + assert(decoded.characterSet === expected.characterSet) + assert(decoded.username === expected.username) + assert(util.Arrays.equals(decoded.authResponse, expected.authResponse)) + assert(decoded.database === expected.database) + assert(decoded.authPluginName === expected.authPluginName) + } + } + + test("encode MySQLAuthSwitchRequestPacket") { + val packet = MySQLAuthSwitchRequestPacket( + 1, + MySQLAuthenticationMethod.NATIVE_PASSWORD.method, + authPluginData) + val expected = decodeHex( + """fe 6d 79 73 71 6c 5f 6e 61 74 69 76 65 5f 70 61 + |73 73 77 6f 72 64 00 77 37 34 35 45 51 55 65 69 + |44 57 32 44 44 33 4e 6d 36 69 74 00 + |""".stripMargin) + verifyEncode(expected, packet) + } + + test("decode MySQLAuthSwitchResponsePacket") { + val payloadHex = decodeHex( + """14 00 00 03 f4 17 96 1f 79 f3 ac 10 0b da a6 b3 + |b5 c2 0e ab 59 85 ff b8 + |""".stripMargin) + val expectedAuthPluginResponse = decodeHex( + """00 00 03 f4 17 96 1f 79 f3 ac 10 0b da a6 b3 b5 + |c2 0e ab 59 85 ff b8 + |""".stripMargin).array + + verifyDecode( + MySQLAuthSwitchResponsePacket, + payloadHex, + MySQLAuthSwitchResponsePacket(20, expectedAuthPluginResponse)) { (decoded, expected) => + assert(decoded.sequenceId === expected.sequenceId) + assert(util.Arrays.equals(decoded.authPluginResponse, expected.authPluginResponse)) + } + } +}